Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
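The formatting can be reproduced locally. Below is a minimal sketch of driving ufmt from Python; it assumes ufmt is installed (pip install ufmt), and the run_ufmt helper and the path list are illustrative, not part of this commit.

import subprocess
import sys


def run_ufmt(paths, check_only=False):
    # "ufmt check" only reports offending files; "ufmt format" rewrites them in place.
    cmd = ["ufmt", "check" if check_only else "format", *paths]
    return subprocess.run(cmd).returncode


if __name__ == "__main__":
    # Path list is illustrative; pass --check to report without modifying files.
    sys.exit(run_ufmt(["torchvision", "test"], check_only="--check" in sys.argv))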
 import bz2
+import gzip
+import hashlib
+import itertools
+import lzma
 import os
 import os.path
-import hashlib
-import gzip
+import pathlib
 import re
 import tarfile
-from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
-from urllib.parse import urlparse
-import zipfile
-import lzma
 import urllib
-import urllib.request
 import urllib.error
-import pathlib
-import itertools
+import urllib.request
+import zipfile
+from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
+from urllib.parse import urlparse
+
 import torch
 from torch.utils.model_zoo import tqdm
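The hunk above shows the import policy usort applies throughout this commit: within a group, plain import statements sort alphabetically ahead of from-imports, and a blank line separates the standard-library group from third-party packages such as torch. A minimal sketch of the resulting layout (module choices are illustrative):

# Standard-library group: sorted `import` statements, then `from` imports.
import os
import re
from typing import Optional

# Third-party group, separated from the standard library by one blank line.
import torch
from torch import nn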
@@ -52,8 +52,8 @@ def gen_bar_updater() -> Callable[[int, int, int], None]:
 def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     md5 = hashlib.md5()
-    with open(fpath, 'rb') as f:
-        for chunk in iter(lambda: f.read(chunk_size), b''):
+    with open(fpath, "rb") as f:
+        for chunk in iter(lambda: f.read(chunk_size), b""):
             md5.update(chunk)
     return md5.hexdigest()
@@ -120,7 +120,7 @@ def download_url(
     # check if file is already present locally
     if check_integrity(fpath, md5):
-        print('Using downloaded and verified file: ' + fpath)
+        print("Using downloaded and verified file: " + fpath)
         return

     if _is_remote_location_available():
@@ -136,13 +136,12 @@ def download_url(
     # download the file
     try:
-        print('Downloading ' + url + ' to ' + fpath)
+        print("Downloading " + url + " to " + fpath)
         _urlretrieve(url, fpath)
     except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
-        if url[:5] == 'https':
-            url = url.replace('https:', 'http:')
-            print('Failed download. Trying https -> http instead.'
-                  ' Downloading ' + url + ' to ' + fpath)
+        if url[:5] == "https":
+            url = url.replace("https:", "http:")
+            print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath)
             _urlretrieve(url, fpath)
         else:
             raise e
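A detail worth noting in the joined print call above: black merged the two physical lines but kept the two adjacent string literals. Python concatenates adjacent literals at compile time, so the printed message is unchanged. A self-contained illustration:

# Adjacent string literals are fused by the compiler, not at runtime, so the
# reformatted one-line call behaves exactly like the old two-line version.
message = "Failed download. Trying https -> http instead." " Downloading"
assert message == "Failed download. Trying https -> http instead. Downloading"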
@@ -202,6 +201,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     """
     # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
     import requests
+
     url = "https://docs.google.com/uc?export=download"

     root = os.path.expanduser(root)
@@ -212,15 +212,15 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
     os.makedirs(root, exist_ok=True)

     if os.path.isfile(fpath) and check_integrity(fpath, md5):
-        print('Using downloaded and verified file: ' + fpath)
+        print("Using downloaded and verified file: " + fpath)
     else:
         session = requests.Session()

-        response = session.get(url, params={'id': file_id}, stream=True)
+        response = session.get(url, params={"id": file_id}, stream=True)
         token = _get_confirm_token(response)

         if token:
-            params = {'id': file_id, 'confirm': token}
+            params = {"id": file_id, "confirm": token}
             response = session.get(url, params=params, stream=True)

         # Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
@@ -240,20 +240,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
             )
             raise RuntimeError(msg)

-        _save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
+        _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
         response.close()


 def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
     for key, value in response.cookies.items():
-        if key.startswith('download_warning'):
+        if key.startswith("download_warning"):
             return value

     return None


 def _save_response_content(
-    response_gen: Iterator[bytes], destination: str,  # type: ignore[name-defined]
+    response_gen: Iterator[bytes],
+    destination: str,  # type: ignore[name-defined]
 ) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
@@ -439,7 +440,10 @@ T = TypeVar("T", str, bytes)


 def verify_str_arg(
-    value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None,
+    value: T,
+    arg: Optional[str] = None,
+    valid_values: Iterable[T] = None,
+    custom_msg: Optional[str] = None,
 ) -> T:
     if not isinstance(value, torch._six.string_classes):
         if arg is None:
@@ -456,10 +460,8 @@ def verify_str_arg(
         if custom_msg is not None:
             msg = custom_msg
         else:
-            msg = ("Unknown value '{value}' for argument {arg}. "
-                   "Valid values are {{{valid_values}}}.")
-            msg = msg.format(value=value, arg=arg,
-                             valid_values=iterable_to_str(valid_values))
+            msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}."
+            msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
         raise ValueError(msg)
     return value
@@ -206,14 +206,14 @@ class VideoClips(object):
         if frame_rate is None:
             frame_rate = fps
         total_frames = len(video_pts) * (float(frame_rate) / fps)
-        idxs = VideoClips._resample_video_idx(
-            int(math.floor(total_frames)), fps, frame_rate
-        )
+        idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
         video_pts = video_pts[idxs]
         clips = unfold(video_pts, num_frames, step)
         if not clips.numel():
-            warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and "
-                          "frames between clips. The video (and potentially others) will be skipped.")
+            warnings.warn(
+                "There aren't enough frames in the current video to get a clip for the given clip length and "
+                "frames between clips. The video (and potentially others) will be skipped."
+            )
         if isinstance(idxs, slice):
             idxs = [idxs] * len(clips)
         else:
@@ -237,9 +237,7 @@ class VideoClips(object):
         self.clips = []
         self.resampling_idxs = []
         for video_pts, fps in zip(self.video_pts, self.video_fps):
-            clips, idxs = self.compute_clips_for_video(
-                video_pts, num_frames, step, fps, frame_rate
-            )
+            clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
             self.clips.append(clips)
             self.resampling_idxs.append(idxs)
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
@@ -295,10 +293,7 @@ class VideoClips(object):
             video_idx (int): index of the video in `video_paths`
         """
         if idx >= self.num_clips():
-            raise IndexError(
-                "Index {} out of range "
-                "({} number of clips)".format(idx, self.num_clips())
-            )
+            raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips()))
         video_idx, clip_idx = self.get_clip_location(idx)
         video_path = self.video_paths[video_idx]
         clip_pts = self.clips[video_idx][clip_idx]
@@ -314,13 +309,9 @@ class VideoClips(object):
             if self._video_height != 0:
                 raise ValueError("pyav backend doesn't support _video_height != 0")
             if self._video_min_dimension != 0:
-                raise ValueError(
-                    "pyav backend doesn't support _video_min_dimension != 0"
-                )
+                raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
             if self._video_max_dimension != 0:
-                raise ValueError(
-                    "pyav backend doesn't support _video_max_dimension != 0"
-                )
+                raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
             if self._audio_samples != 0:
                 raise ValueError("pyav backend doesn't support _audio_samples != 0")
@@ -338,19 +329,11 @@ class VideoClips(object):
             audio_start_pts, audio_end_pts = 0, -1
             audio_timebase = Fraction(0, 1)
-            video_timebase = Fraction(
-                info.video_timebase.numerator, info.video_timebase.denominator
-            )
+            video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
             if info.has_audio:
-                audio_timebase = Fraction(
-                    info.audio_timebase.numerator, info.audio_timebase.denominator
-                )
-                audio_start_pts = pts_convert(
-                    video_start_pts, video_timebase, audio_timebase, math.floor
-                )
-                audio_end_pts = pts_convert(
-                    video_end_pts, video_timebase, audio_timebase, math.ceil
-                )
+                audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
+                audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
+                audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
                 audio_fps = info.audio_sample_rate
             video, audio, info = _read_video_from_file(
                 video_path,
@@ -376,9 +359,7 @@ class VideoClips(object):
             resampling_idx = resampling_idx - resampling_idx[0]
             video = video[resampling_idx]
             info["video_fps"] = self.frame_rate
-        assert len(video) == self.num_frames, "{} x {}".format(
-            video.shape, self.num_frames
-        )
+        assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
         return video, audio, info, video_idx

     def __getstate__(self):
...
 import os
+from typing import Any, Callable, List, Optional, Tuple
+
 import torch
 import torch.utils.data as data
-from typing import Any, Callable, List, Optional, Tuple


 class VisionDataset(data.Dataset):
@@ -22,14 +23,15 @@ class VisionDataset(data.Dataset):
     :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
     """
+
     _repr_indent = 4

     def __init__(
         self,
         root: str,
         transforms: Optional[Callable] = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
         torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}")
         if isinstance(root, torch._six.string_classes):
@@ -39,8 +41,7 @@ class VisionDataset(data.Dataset):
         has_transforms = transforms is not None
         has_separate_transform = transform is not None or target_transform is not None
         if has_transforms and has_separate_transform:
-            raise ValueError("Only transforms or transform/target_transform can "
-                             "be passed as argument")
+            raise ValueError("Only transforms or transform/target_transform can " "be passed as argument")

         # for backwards-compatibility
         self.transform = transform
@@ -72,12 +73,11 @@ class VisionDataset(data.Dataset):
         if hasattr(self, "transforms") and self.transforms is not None:
             body += [repr(self.transforms)]
         lines = [head] + [" " * self._repr_indent + line for line in body]
-        return '\n'.join(lines)
+        return "\n".join(lines)

     def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
         lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+        return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

     def extra_repr(self) -> str:
         return ""
@@ -97,16 +97,13 @@ class StandardTransform(object):
     def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
         lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+        return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]

     def __repr__(self) -> str:
         body = [self.__class__.__name__]
         if self.transform is not None:
-            body += self._format_transform_repr(self.transform,
-                                                "Transform: ")
+            body += self._format_transform_repr(self.transform, "Transform: ")
         if self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transform: ")
+            body += self._format_transform_repr(self.target_transform, "Target transform: ")
-        return '\n'.join(body)
+        return "\n".join(body)
-import os
 import collections
-from .vision import VisionDataset
+import os
 from xml.etree.ElementTree import Element as ET_Element
+
+from .vision import VisionDataset

 try:
     from defusedxml.ElementTree import parse as ET_parse
 except ImportError:
     from xml.etree.ElementTree import parse as ET_parse
-from PIL import Image
+import warnings
 from typing import Any, Callable, Dict, Optional, Tuple, List
+
+from PIL import Image
+
 from .utils import download_and_extract_archive, verify_str_arg
-import warnings

 DATASET_YEAR_DICT = {
-    '2012': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
-        'filename': 'VOCtrainval_11-May-2012.tar',
-        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2012')
-    },
-    '2011': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
-        'filename': 'VOCtrainval_25-May-2011.tar',
-        'md5': '6c3384ef61512963050cb5d687e5bf1e',
-        'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011')
-    },
-    '2010': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
-        'filename': 'VOCtrainval_03-May-2010.tar',
-        'md5': 'da459979d0c395079b5c75ee67908abb',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2010')
-    },
-    '2009': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
-        'filename': 'VOCtrainval_11-May-2009.tar',
-        'md5': '59065e4b188729180974ef6572f6a212',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2009')
-    },
-    '2008': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
-        'filename': 'VOCtrainval_11-May-2012.tar',
-        'md5': '2629fa636546599198acfcfbfcf1904a',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2008')
-    },
-    '2007': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
-        'filename': 'VOCtrainval_06-Nov-2007.tar',
-        'md5': 'c52e279531787c972589f7e41ab4ae64',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
-    },
-    '2007-test': {
-        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
-        'filename': 'VOCtest_06-Nov-2007.tar',
-        'md5': 'b6e924de25625d8de591ea690078ad9f',
-        'base_dir': os.path.join('VOCdevkit', 'VOC2007')
-    }
+    "2012": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "6cd6e144f989b92b3379bac3b3de84fd",
+        "base_dir": os.path.join("VOCdevkit", "VOC2012"),
+    },
+    "2011": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
+        "filename": "VOCtrainval_25-May-2011.tar",
+        "md5": "6c3384ef61512963050cb5d687e5bf1e",
+        "base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
+    },
+    "2010": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
+        "filename": "VOCtrainval_03-May-2010.tar",
+        "md5": "da459979d0c395079b5c75ee67908abb",
+        "base_dir": os.path.join("VOCdevkit", "VOC2010"),
+    },
+    "2009": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
+        "filename": "VOCtrainval_11-May-2009.tar",
+        "md5": "59065e4b188729180974ef6572f6a212",
+        "base_dir": os.path.join("VOCdevkit", "VOC2009"),
+    },
+    "2008": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
+        "filename": "VOCtrainval_11-May-2012.tar",
+        "md5": "2629fa636546599198acfcfbfcf1904a",
+        "base_dir": os.path.join("VOCdevkit", "VOC2008"),
+    },
+    "2007": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
+        "filename": "VOCtrainval_06-Nov-2007.tar",
+        "md5": "c52e279531787c972589f7e41ab4ae64",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
+    "2007-test": {
+        "url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
+        "filename": "VOCtest_06-Nov-2007.tar",
+        "md5": "b6e924de25625d8de591ea690078ad9f",
+        "base_dir": os.path.join("VOCdevkit", "VOC2007"),
+    },
 }
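The rewritten dictionary also gains a trailing comma after the last element of every entry. This is black's "magic trailing comma": once a collection ends in a comma, black keeps it one element per line on every future run instead of collapsing it. A small illustration with a hypothetical entry:

# The trailing comma after the "md5" pair pins this dict to the exploded,
# one-pair-per-line layout even though it would fit on a single line.
ARCHIVE = {
    "url": "http://example.com/archive.tar",
    "md5": "0123456789abcdef",
}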
...
-from PIL import Image
 import os
 from os.path import abspath, expanduser
-import torch
 from typing import Any, Callable, List, Dict, Optional, Tuple, Union
-from .utils import check_integrity, download_file_from_google_drive, \
-    download_and_extract_archive, extract_archive, verify_str_arg
+
+import torch
+from PIL import Image
+
+from .utils import (
+    check_integrity,
+    download_file_from_google_drive,
+    download_and_extract_archive,
+    extract_archive,
+    verify_str_arg,
+)
 from .vision import VisionDataset
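The parenthesized from-import above replaces the old backslash continuation; combined with the trailing comma, it keeps one imported name per line. A minimal sketch of the same style (names are illustrative):

# Parentheses, not a trailing backslash, are the idiomatic way to wrap a long
# from-import; the trailing comma makes black keep one name per line.
from os.path import (
    abspath,
    expanduser,
)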
@@ -40,25 +47,25 @@ class WIDERFace(VisionDataset):
         # File ID                        MD5 Hash                            Filename
         ("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
         ("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
-        ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip")
+        ("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
     ]
     ANNOTATIONS_FILE = (
         "http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip",
         "0e3767bcf0e326556d407bf5bff5d27c",
-        "wider_face_split.zip"
+        "wider_face_split.zip",
     )

     def __init__(
         self,
         root: str,
         split: str = "train",
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
         download: bool = False,
     ) -> None:
-        super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
-                                        transform=transform,
-                                        target_transform=target_transform)
+        super(WIDERFace, self).__init__(
+            root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
+        )
         # check arguments
         self.split = verify_str_arg(split, "split", ("train", "val", "test"))
@@ -66,8 +73,9 @@ class WIDERFace(VisionDataset):
             self.download()

         if not self._check_integrity():
-            raise RuntimeError("Dataset not found or corrupted. " +
-                               "You can use download=True to download and prepare it")
+            raise RuntimeError(
+                "Dataset not found or corrupted. " + "You can use download=True to download and prepare it"
+            )

         self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
         if self.split in ("train", "val"):
@@ -102,7 +110,7 @@ class WIDERFace(VisionDataset):

     def extra_repr(self) -> str:
         lines = ["Split: {split}"]
-        return '\n'.join(lines).format(**self.__dict__)
+        return "\n".join(lines).format(**self.__dict__)

     def parse_train_val_annotations_file(self) -> None:
         filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
@@ -133,16 +141,20 @@ class WIDERFace(VisionDataset):
                         box_annotation_line = False
                         file_name_line = True
                         labels_tensor = torch.tensor(labels)
-                        self.img_info.append({
-                            "img_path": img_path,
-                            "annotations": {"bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                                            "blur": labels_tensor[:, 4],
-                                            "expression": labels_tensor[:, 5],
-                                            "illumination": labels_tensor[:, 6],
-                                            "occlusion": labels_tensor[:, 7],
-                                            "pose": labels_tensor[:, 8],
-                                            "invalid": labels_tensor[:, 9]}
-                        })
+                        self.img_info.append(
+                            {
+                                "img_path": img_path,
+                                "annotations": {
+                                    "bbox": labels_tensor[:, 0:4],  # x, y, width, height
+                                    "blur": labels_tensor[:, 4],
+                                    "expression": labels_tensor[:, 5],
+                                    "illumination": labels_tensor[:, 6],
+                                    "occlusion": labels_tensor[:, 7],
+                                    "pose": labels_tensor[:, 8],
+                                    "invalid": labels_tensor[:, 9],
+                                },
+                            }
+                        )
                         box_counter = 0
                         labels.clear()
                     else:
@@ -172,7 +184,7 @@ class WIDERFace(VisionDataset):

     def download(self) -> None:
         if self._check_integrity():
-            print('Files already downloaded and verified')
+            print("Files already downloaded and verified")
             return

         # download and extract image data
@@ -182,6 +194,6 @@ class WIDERFace(VisionDataset):
             extract_archive(filepath)

         # download and extract annotation files
-        download_and_extract_archive(url=self.ANNOTATIONS_FILE[0],
-                                     download_root=self.root,
-                                     md5=self.ANNOTATIONS_FILE[1])
+        download_and_extract_archive(
+            url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
+        )
@@ -11,12 +11,14 @@ def _has_ops():

 try:
-    lib_path = _get_extension_path('_C')
+    lib_path = _get_extension_path("_C")
     torch.ops.load_library(lib_path)
     _HAS_OPS = True
+
     def _has_ops():  # noqa: F811
         return True
+
 except (ImportError, OSError):
     pass
@@ -41,6 +43,7 @@ def _check_cuda_version():
     if not _HAS_OPS:
         return -1
     import torch
+
     _version = torch.ops.torchvision._cuda_version()
     if _version != -1 and torch.version.cuda is not None:
         tv_version = str(_version)
@@ -51,14 +54,17 @@ def _check_cuda_version():
         tv_major = int(tv_version[0:2])
         tv_minor = int(tv_version[3])
         t_version = torch.version.cuda
-        t_version = t_version.split('.')
+        t_version = t_version.split(".")
         t_major = int(t_version[0])
         t_minor = int(t_version[1])
         if t_major != tv_major or t_minor != tv_minor:
-            raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
-                               "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
-                               "Please reinstall the torchvision that matches your PyTorch install."
-                               .format(t_major, t_minor, tv_major, tv_minor))
+            raise RuntimeError(
+                "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
+                "PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
+                "Please reinstall the torchvision that matches your PyTorch install.".format(
+                    t_major, t_minor, tv_major, tv_minor
+                )
+            )
     return _version
...
-import torch
 from typing import Any, Dict, Iterator

+import torch
+
 from ._video_opt import (
     Timebase,
     VideoMetaData,
@@ -12,11 +13,6 @@ from ._video_opt import (
     _read_video_timestamps_from_file,
     _read_video_timestamps_from_memory,
 )
-from .video import (
-    read_video,
-    read_video_timestamps,
-    write_video,
-)
 from .image import (
     ImageReadMode,
     decode_image,
@@ -30,6 +26,11 @@ from .image import (
     write_jpeg,
     write_png,
 )
+from .video import (
+    read_video,
+    read_video_timestamps,
+    write_video,
+)


 if _HAS_VIDEO_OPT:
@@ -127,10 +128,10 @@ class VideoReader:
             raise StopIteration
         return {"data": frame, "pts": pts}

-    def __iter__(self) -> Iterator['VideoReader']:
+    def __iter__(self) -> Iterator["VideoReader"]:
         return self

-    def seek(self, time_s: float) -> 'VideoReader':
+    def seek(self, time_s: float) -> "VideoReader":
         """Seek within current stream.

         Args:
...
 import math
 import os
 import warnings
@@ -12,7 +11,7 @@ from .._internally_replaced_utils import _get_extension_path

 try:
-    lib_path = _get_extension_path('video_reader')
+    lib_path = _get_extension_path("video_reader")
     torch.ops.load_library(lib_path)
     _HAS_VIDEO_OPT = True
 except (ImportError, OSError):
@@ -90,9 +89,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
     """
     meta = VideoMetaData()
     if vtimebase.numel() > 0:
-        meta.video_timebase = Timebase(
-            int(vtimebase[0].item()), int(vtimebase[1].item())
-        )
+        meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
         timebase = vtimebase[0].item() / float(vtimebase[1].item())
         if vduration.numel() > 0:
             meta.has_video = True
@@ -100,9 +97,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
     if vfps.numel() > 0:
         meta.video_fps = float(vfps.item())
     if atimebase.numel() > 0:
-        meta.audio_timebase = Timebase(
-            int(atimebase[0].item()), int(atimebase[1].item())
-        )
+        meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
         timebase = atimebase[0].item() / float(atimebase[1].item())
         if aduration.numel() > 0:
             meta.has_audio = True
@@ -216,10 +211,7 @@ def _read_video_from_file(
         audio_timebase.numerator,
         audio_timebase.denominator,
     )
-    vframes, _vframe_pts, vtimebase, vfps, vduration, \
-        aframes, aframe_pts, atimebase, asample_rate, aduration = (
-            result
-        )
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
     info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
     if aframes.numel() > 0:
         # when audio stream is found
@@ -254,8 +246,7 @@ def _read_video_timestamps_from_file(filename):
         0,  # audio_timebase_num
         1,  # audio_timebase_den
     )
-    _vframes, vframe_pts, vtimebase, vfps, vduration, \
-        _aframes, aframe_pts, atimebase, asample_rate, aduration = result
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
     info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

     vframe_pts = vframe_pts.numpy().tolist()
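The ten-name unpacking above only fits on one line because the repository configures black with a 120-character line limit rather than the default 88 (an inference from the joined lines in this diff, not something stated in the excerpt). A quick check:

# With its four-space indent the joined unpacking line is 119 characters,
# inside the assumed 120-character limit; black's default 88 would rewrap it.
line = "    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result"
assert len(line) <= 120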
@@ -372,10 +363,7 @@ def _read_video_from_memory(
         audio_timebase_denominator,
     )
-    vframes, _vframe_pts, vtimebase, vfps, vduration, \
-        aframes, aframe_pts, atimebase, asample_rate, aduration = (
-            result
-        )
+    vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result

     if aframes.numel() > 0:
         # when audio stream is found
@@ -413,10 +401,7 @@ def _read_video_timestamps_from_memory(video_data):
         0,  # audio_timebase_num
         1,  # audio_timebase_den
     )
-    _vframes, vframe_pts, vtimebase, vfps, vduration, \
-        _aframes, aframe_pts, atimebase, asample_rate, aduration = (
-            result
-        )
+    _vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
     info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)

     vframe_pts = vframe_pts.numpy().tolist()
@@ -439,10 +424,10 @@ def _probe_video_from_memory(video_data):


 def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
-    if pts_unit == 'pts':
+    if pts_unit == "pts":
         start_pts = float(start_pts * time_base)
         end_pts = float(end_pts * time_base)
-        pts_unit = 'sec'
+        pts_unit = "sec"
     return start_pts, end_pts, pts_unit
@@ -467,20 +452,15 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
     time_base = default_timebase
     if has_video:
-        video_timebase = Fraction(
-            info.video_timebase.numerator, info.video_timebase.denominator
-        )
+        video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
         time_base = video_timebase
     if has_audio:
-        audio_timebase = Fraction(
-            info.audio_timebase.numerator, info.audio_timebase.denominator
-        )
+        audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
         time_base = time_base if time_base else audio_timebase

     # video_timebase is the default time_base
-    start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
-        start_pts, end_pts, pts_unit, time_base)
+    start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base)

     def get_pts(time_base):
         start_offset = start_pts_sec
@@ -527,9 +507,7 @@ def _read_video_timestamps(filename, pts_unit="pts"):
     pts, _, info = _read_video_timestamps_from_file(filename)

     if pts_unit == "sec":
-        video_time_base = Fraction(
-            info.video_timebase.numerator, info.video_timebase.denominator
-        )
+        video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
         pts = [x * video_time_base for x in pts]

     video_fps = info.video_fps if info.has_video else None
...
-import torch
 from enum import Enum

+import torch
+
 from .._internally_replaced_utils import _get_extension_path

 try:
-    lib_path = _get_extension_path('image')
+    lib_path = _get_extension_path("image")
     torch.ops.load_library(lib_path)
 except (ImportError, OSError):
     pass
@@ -21,6 +22,7 @@ class ImageReadMode(Enum):
     ``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
     RGB with transparency.
     """
+
     UNCHANGED = 0
     GRAY = 1
     GRAY_ALPHA = 2
@@ -111,8 +113,9 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
     write_file(filename, output)


-def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
-                device: str = 'cpu') -> torch.Tensor:
+def decode_jpeg(
+    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
+) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB Tensor.
     Optionally converts the image to the desired format.
@@ -135,7 +138,7 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
         output (Tensor[image_channels, image_height, image_width])
     """
     device = torch.device(device)
-    if device.type == 'cuda':
+    if device.type == "cuda":
         output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
     else:
         output = torch.ops.image.decode_jpeg(input, mode.value)
@@ -158,8 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
         JPEG file.
     """
     if quality < 1 or quality > 100:
-        raise ValueError('Image quality should be a positive number '
-                         'between 1 and 100')
+        raise ValueError("Image quality should be a positive number " "between 1 and 100")
     output = torch.ops.image.encode_jpeg(input, quality)
     return output
...
@@ -94,16 +94,16 @@ def write_video(
     if audio_array is not None:
         audio_format_dtypes = {
-            'dbl': '<f8',
-            'dblp': '<f8',
-            'flt': '<f4',
-            'fltp': '<f4',
-            's16': '<i2',
-            's16p': '<i2',
-            's32': '<i4',
-            's32p': '<i4',
-            'u8': 'u1',
-            'u8p': 'u1',
+            "dbl": "<f8",
+            "dblp": "<f8",
+            "flt": "<f4",
+            "fltp": "<f4",
+            "s16": "<i2",
+            "s16p": "<i2",
+            "s32": "<i4",
+            "s32p": "<i4",
+            "u8": "u1",
+            "u8p": "u1",
         }
         a_stream = container.add_stream(audio_codec, rate=audio_fps)
         a_stream.options = audio_options or {}
@@ -115,9 +115,7 @@ def write_video(
         format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
         audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)

-        frame = av.AudioFrame.from_ndarray(
-            audio_array, format=audio_sample_fmt, layout=audio_layout
-        )
+        frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

         frame.sample_rate = audio_fps
@@ -207,9 +205,7 @@ def _read_from_stream(
             # TODO add a warning
             pass
     # ensure that the results are sorted wrt the pts
-    result = [
-        frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
-    ]
+    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
     if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
         # if there is no frame that exactly matches the pts of start_offset
         # add the last frame smaller than start_offset, to guarantee that
@@ -264,7 +260,7 @@ def read_video(
     from torchvision import get_video_backend

     if not os.path.exists(filename):
-        raise RuntimeError(f'File not found: {filename}')
+        raise RuntimeError(f"File not found: {filename}")

     if get_video_backend() != "pyav":
         return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
@@ -276,8 +272,7 @@ def read_video(
     if end_pts < start_pts:
         raise ValueError(
-            "end_pts should be larger than start_pts, got "
-            "start_pts={} and end_pts={}".format(start_pts, end_pts)
+            "end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)
         )

     info = {}
@@ -295,8 +290,7 @@ def read_video(
         elif container.streams.audio:
             time_base = container.streams.audio[0].time_base
         # video_timebase is the default time_base
-        start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(
-            start_pts, end_pts, pts_unit, time_base)
+        start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base)
         if container.streams.video:
             video_frames = _read_from_stream(
                 container,
@@ -337,7 +331,7 @@ def read_video(
     if aframes_list:
         aframes = np.concatenate(aframes_list, 1)
         aframes = torch.as_tensor(aframes)
-        if pts_unit == 'sec':
+        if pts_unit == "sec":
             start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
             if end_pts != float("inf"):
                 end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
...
@@ -10,8 +10,8 @@ from .mnasnet import *
 from .shufflenetv2 import *
 from .efficientnet import *
 from .regnet import *
-from . import segmentation
 from . import detection
-from . import video
-from . import quantization
 from . import feature_extraction
+from . import quantization
+from . import segmentation
+from . import video
 from collections import OrderedDict
+from typing import Dict, Optional

 from torch import nn
-from typing import Dict, Optional


 class IntermediateLayerGetter(nn.ModuleDict):
@@ -35,6 +35,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
     >>> [('feat1', torch.Size([1, 64, 56, 56])),
     >>>  ('feat2', torch.Size([1, 256, 14, 14]))]
     """
+
     _version = 2
     __annotations__ = {
         "return_layers": Dict[str, str],
...
+from typing import Any
+
 import torch
 import torch.nn as nn
+
 from .._internally_replaced_utils import load_state_dict_from_url
-from typing import Any

-__all__ = ['AlexNet', 'alexnet']
+__all__ = ["AlexNet", "alexnet"]


 model_urls = {
-    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-7be5be79.pth',
+    "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
 }


 class AlexNet(nn.Module):
     def __init__(self, num_classes: int = 1000) -> None:
         super(AlexNet, self).__init__()
         self.features = nn.Sequential(
@@ -61,7 +62,6 @@ def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> A
     """
     model = AlexNet(**kwargs)
     if pretrained:
-        state_dict = load_state_dict_from_url(model_urls['alexnet'],
-                                              progress=progress)
+        state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress)
         model.load_state_dict(state_dict)
     return model
import re import re
from collections import OrderedDict
from typing import Any, List, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from collections import OrderedDict
from .._internally_replaced_utils import load_state_dict_from_url
from torch import Tensor from torch import Tensor
from typing import Any, List, Tuple
from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] __all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"]
model_urls = { model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
} }
class _DenseLayer(nn.Module): class _DenseLayer(nn.Module):
def __init__( def __init__(
self, self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
num_input_features: int,
growth_rate: int,
bn_size: int,
drop_rate: float,
memory_efficient: bool = False
) -> None: ) -> None:
super(_DenseLayer, self).__init__() super(_DenseLayer, self).__init__()
self.norm1: nn.BatchNorm2d self.norm1: nn.BatchNorm2d
self.add_module('norm1', nn.BatchNorm2d(num_input_features)) self.add_module("norm1", nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU self.relu1: nn.ReLU
self.add_module('relu1', nn.ReLU(inplace=True)) self.add_module("relu1", nn.ReLU(inplace=True))
self.conv1: nn.Conv2d self.conv1: nn.Conv2d
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * self.add_module(
growth_rate, kernel_size=1, stride=1, "conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
bias=False)) )
self.norm2: nn.BatchNorm2d self.norm2: nn.BatchNorm2d
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)) self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
self.relu2: nn.ReLU self.relu2: nn.ReLU
self.add_module('relu2', nn.ReLU(inplace=True)) self.add_module("relu2", nn.ReLU(inplace=True))
self.conv2: nn.Conv2d self.conv2: nn.Conv2d
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, self.add_module(
kernel_size=3, stride=1, padding=1, "conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
bias=False)) )
self.drop_rate = float(drop_rate) self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient self.memory_efficient = memory_efficient
...@@ -93,8 +90,7 @@ class _DenseLayer(nn.Module): ...@@ -93,8 +90,7 @@ class _DenseLayer(nn.Module):
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
if self.drop_rate > 0: if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate, new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
training=self.training)
return new_features return new_features
...@@ -108,7 +104,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -108,7 +104,7 @@ class _DenseBlock(nn.ModuleDict):
bn_size: int, bn_size: int,
growth_rate: int, growth_rate: int,
drop_rate: float, drop_rate: float,
memory_efficient: bool = False memory_efficient: bool = False,
) -> None: ) -> None:
super(_DenseBlock, self).__init__() super(_DenseBlock, self).__init__()
for i in range(num_layers): for i in range(num_layers):
...@@ -119,7 +115,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -119,7 +115,7 @@ class _DenseBlock(nn.ModuleDict):
drop_rate=drop_rate, drop_rate=drop_rate,
memory_efficient=memory_efficient, memory_efficient=memory_efficient,
) )
self.add_module('denselayer%d' % (i + 1), layer) self.add_module("denselayer%d" % (i + 1), layer)
def forward(self, init_features: Tensor) -> Tensor: def forward(self, init_features: Tensor) -> Tensor:
features = [init_features] features = [init_features]
...@@ -132,11 +128,10 @@ class _DenseBlock(nn.ModuleDict): ...@@ -132,11 +128,10 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential): class _Transition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None: def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__() super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features)) self.add_module("norm", nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True)) self.add_module("relu", nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
kernel_size=1, stride=1, bias=False)) self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module): class DenseNet(nn.Module):
...@@ -163,19 +158,22 @@ class DenseNet(nn.Module): ...@@ -163,19 +158,22 @@ class DenseNet(nn.Module):
bn_size: int = 4, bn_size: int = 4,
drop_rate: float = 0, drop_rate: float = 0,
num_classes: int = 1000, num_classes: int = 1000,
memory_efficient: bool = False memory_efficient: bool = False,
) -> None: ) -> None:
super(DenseNet, self).__init__() super(DenseNet, self).__init__()
# First convolution # First convolution
self.features = nn.Sequential(OrderedDict([ self.features = nn.Sequential(
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, OrderedDict(
padding=3, bias=False)), [
('norm0', nn.BatchNorm2d(num_init_features)), ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
('relu0', nn.ReLU(inplace=True)), ("norm0", nn.BatchNorm2d(num_init_features)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), ("relu0", nn.ReLU(inplace=True)),
])) ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]
)
)
# Each denseblock # Each denseblock
num_features = num_init_features num_features = num_init_features
...@@ -186,18 +184,17 @@ class DenseNet(nn.Module): ...@@ -186,18 +184,17 @@ class DenseNet(nn.Module):
bn_size=bn_size, bn_size=bn_size,
growth_rate=growth_rate, growth_rate=growth_rate,
drop_rate=drop_rate, drop_rate=drop_rate,
memory_efficient=memory_efficient memory_efficient=memory_efficient,
) )
self.features.add_module('denseblock%d' % (i + 1), block) self.features.add_module("denseblock%d" % (i + 1), block)
num_features = num_features + num_layers * growth_rate num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1: if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features, trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
num_output_features=num_features // 2) self.features.add_module("transition%d" % (i + 1), trans)
self.features.add_module('transition%d' % (i + 1), trans)
num_features = num_features // 2 num_features = num_features // 2
# Final batch norm # Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features)) self.features.add_module("norm5", nn.BatchNorm2d(num_features))
# Linear layer # Linear layer
self.classifier = nn.Linear(num_features, num_classes) self.classifier = nn.Linear(num_features, num_classes)
...@@ -227,7 +224,8 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None: ...@@ -227,7 +224,8 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# They are also in the checkpoints in model_urls. This pattern is used # They are also in the checkpoints in model_urls. This pattern is used
# to find such keys. # to find such keys.
pattern = re.compile( pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
state_dict = load_state_dict_from_url(model_url, progress=progress) state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
...@@ -246,7 +244,7 @@ def _densenet( ...@@ -246,7 +244,7 @@ def _densenet(
num_init_features: int, num_init_features: int,
pretrained: bool, pretrained: bool,
progress: bool, progress: bool,
**kwargs: Any **kwargs: Any,
) -> DenseNet: ) -> DenseNet:
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained: if pretrained:
...@@ -265,8 +263,7 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -265,8 +263,7 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_. but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
""" """
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs)
**kwargs)
def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
...@@ -280,8 +277,7 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -280,8 +277,7 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_. but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
""" """
return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs)
**kwargs)
def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
...@@ -295,8 +291,7 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -295,8 +291,7 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_. but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
""" """
return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs)
**kwargs)
def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet: def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
...@@ -310,5 +305,4 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) ...@@ -310,5 +305,4 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_. but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
""" """
return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs)
**kwargs)
import math import math
import torch
from collections import OrderedDict from collections import OrderedDict
from torch import Tensor
from typing import List, Tuple from typing import List, Tuple
import torch
from torch import Tensor
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
...@@ -61,12 +60,8 @@ class BalancedPositiveNegativeSampler(object): ...@@ -61,12 +60,8 @@ class BalancedPositiveNegativeSampler(object):
neg_idx_per_image = negative[perm2] neg_idx_per_image = negative[perm2]
# create binary mask from indices # create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like( pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
matched_idxs_per_image, dtype=torch.uint8 neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask[pos_idx_per_image] = 1 pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1 neg_idx_per_image_mask[neg_idx_per_image] = 1
...@@ -132,7 +127,7 @@ class BoxCoder(object): ...@@ -132,7 +127,7 @@ class BoxCoder(object):
the representation used for training the regressors. the representation used for training the regressors.
""" """
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)): def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
# type: (Tuple[float, float, float, float], float) -> None # type: (Tuple[float, float, float, float], float) -> None
""" """
Args: Args:
...@@ -177,9 +172,7 @@ class BoxCoder(object): ...@@ -177,9 +172,7 @@ class BoxCoder(object):
box_sum += val box_sum += val
if box_sum > 0: if box_sum > 0:
rel_codes = rel_codes.reshape(box_sum, -1) rel_codes = rel_codes.reshape(box_sum, -1)
pred_boxes = self.decode_single( pred_boxes = self.decode_single(rel_codes, concat_boxes)
rel_codes, concat_boxes
)
if box_sum > 0: if box_sum > 0:
pred_boxes = pred_boxes.reshape(box_sum, -1, 4) pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes return pred_boxes
...@@ -247,8 +240,8 @@ class Matcher(object): ...@@ -247,8 +240,8 @@ class Matcher(object):
BETWEEN_THRESHOLDS = -2 BETWEEN_THRESHOLDS = -2
__annotations__ = { __annotations__ = {
'BELOW_LOW_THRESHOLD': int, "BELOW_LOW_THRESHOLD": int,
'BETWEEN_THRESHOLDS': int, "BETWEEN_THRESHOLDS": int,
} }
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False): def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
...@@ -287,13 +280,9 @@ class Matcher(object): ...@@ -287,13 +280,9 @@ class Matcher(object):
if match_quality_matrix.numel() == 0: if match_quality_matrix.numel() == 0:
# empty targets or proposals not supported during training # empty targets or proposals not supported during training
if match_quality_matrix.shape[0] == 0: if match_quality_matrix.shape[0] == 0:
raise ValueError( raise ValueError("No ground-truth boxes available for one of the images " "during training")
"No ground-truth boxes available for one of the images "
"during training")
else: else:
raise ValueError( raise ValueError("No proposal boxes available for one of the images " "during training")
"No proposal boxes available for one of the images "
"during training")
# match_quality_matrix is M (gt) x N (predicted) # match_quality_matrix is M (gt) x N (predicted)
# Max over gt elements (dim 0) to find best gt candidate for each prediction # Max over gt elements (dim 0) to find best gt candidate for each prediction
...@@ -305,9 +294,7 @@ class Matcher(object): ...@@ -305,9 +294,7 @@ class Matcher(object):
# Assign candidate matches with low quality to negative (unassigned) values # Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold below_low_threshold = matched_vals < self.low_threshold
between_thresholds = (matched_vals >= self.low_threshold) & ( between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
matched_vals < self.high_threshold
)
matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
matches[between_thresholds] = self.BETWEEN_THRESHOLDS matches[between_thresholds] = self.BETWEEN_THRESHOLDS
...@@ -328,9 +315,7 @@ class Matcher(object): ...@@ -328,9 +315,7 @@ class Matcher(object):
# For each gt, find the prediction with which it has highest quality # For each gt, find the prediction with which it has highest quality
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
# Find highest quality match available, even if it is low, including ties # Find highest quality match available, even if it is low, including ties
gt_pred_pairs_of_highest_quality = torch.where( gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
match_quality_matrix == highest_quality_foreach_gt[:, None]
)
# Example gt_pred_pairs_of_highest_quality: # Example gt_pred_pairs_of_highest_quality:
# tensor([[ 0, 39796], # tensor([[ 0, 39796],
# [ 1, 32055], # [ 1, 32055],
...@@ -350,7 +335,6 @@ class Matcher(object): ...@@ -350,7 +335,6 @@ class Matcher(object):
class SSDMatcher(Matcher): class SSDMatcher(Matcher):
def __init__(self, threshold): def __init__(self, threshold):
super().__init__(threshold, threshold, allow_low_quality_matches=False) super().__init__(threshold, threshold, allow_low_quality_matches=False)
...@@ -359,9 +343,9 @@ class SSDMatcher(Matcher): ...@@ -359,9 +343,9 @@ class SSDMatcher(Matcher):
# For each gt, find the prediction with which it has the highest quality # For each gt, find the prediction with which it has the highest quality
_, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1) _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
matches[highest_quality_pred_foreach_gt] = torch.arange(highest_quality_pred_foreach_gt.size(0), matches[highest_quality_pred_foreach_gt] = torch.arange(
dtype=torch.int64, highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
device=highest_quality_pred_foreach_gt.device) )
return matches return matches
...@@ -405,7 +389,7 @@ def retrieve_out_channels(model, size): ...@@ -405,7 +389,7 @@ def retrieve_out_channels(model, size):
tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device) tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
features = model(tmp_img) features = model(tmp_img)
if isinstance(features, torch.Tensor): if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)]) features = OrderedDict([("0", features)])
out_channels = [x.size(1) for x in features.values()] out_channels = [x.size(1) for x in features.values()]
if in_training: if in_training:
......
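For orientation, the Matcher above maps each prediction to a ground-truth index or to a negative sentinel. A small sketch of that contract; note these utilities live in the private torchvision.models.detection._utils module, so the import path is an assumption that may shift between versions:

import torch
from torchvision.models.detection._utils import Matcher

# Quality matrix: rows are ground-truth boxes, columns are predictions.
match_quality = torch.tensor([[0.9, 0.2, 0.4],
                              [0.1, 0.8, 0.5]])
matcher = Matcher(high_threshold=0.7, low_threshold=0.3, allow_low_quality_matches=True)

# Predictions 0 and 1 match gt 0 and gt 1; prediction 2 (quality 0.5) falls
# between the thresholds and receives the BETWEEN_THRESHOLDS sentinel (-2).
print(matcher(match_quality))  # tensor([ 0,  1, -2])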
import math import math
from typing import List, Optional
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from typing import List, Optional
from .image_list import ImageList from .image_list import ImageList
...@@ -48,15 +49,21 @@ class AnchorGenerator(nn.Module): ...@@ -48,15 +49,21 @@ class AnchorGenerator(nn.Module):
self.sizes = sizes self.sizes = sizes
self.aspect_ratios = aspect_ratios self.aspect_ratios = aspect_ratios
self.cell_anchors = [self.generate_anchors(size, aspect_ratio) self.cell_anchors = [
for size, aspect_ratio in zip(sizes, aspect_ratios)] self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
]
# TODO: https://github.com/pytorch/pytorch/issues/26792 # TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor. # This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32, def generate_anchors(
device: torch.device = torch.device("cpu")): self,
scales: List[int],
aspect_ratios: List[float],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
scales = torch.as_tensor(scales, dtype=dtype, device=device) scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios) h_ratios = torch.sqrt(aspect_ratios)
...@@ -69,8 +76,7 @@ class AnchorGenerator(nn.Module): ...@@ -69,8 +76,7 @@ class AnchorGenerator(nn.Module):
return base_anchors.round() return base_anchors.round()
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
for cell_anchor in self.cell_anchors]
def num_anchors_per_location(self): def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
...@@ -83,25 +89,21 @@ class AnchorGenerator(nn.Module): ...@@ -83,25 +89,21 @@ class AnchorGenerator(nn.Module):
assert cell_anchors is not None assert cell_anchors is not None
if not (len(grid_sizes) == len(strides) == len(cell_anchors)): if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature " raise ValueError(
"map could potentially have different sizes and aspect ratios. " "Anchors should be Tuple[Tuple[int]] because each feature "
"There needs to be a match between the number of " "map could potentially have different sizes and aspect ratios. "
"feature maps passed and the number of sizes / aspect ratios specified.") "There needs to be a match between the number of "
"feature maps passed and the number of sizes / aspect ratios specified."
for size, stride, base_anchors in zip( )
grid_sizes, strides, cell_anchors
): for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
grid_height, grid_width = size grid_height, grid_width = size
stride_height, stride_width = stride stride_height, stride_width = stride
device = base_anchors.device device = base_anchors.device
# For output anchor, compute [x_center, y_center, x_center, y_center] # For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange( shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
0, grid_width, dtype=torch.int32, device=device shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.int32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1) shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1) shift_y = shift_y.reshape(-1)
...@@ -109,9 +111,7 @@ class AnchorGenerator(nn.Module): ...@@ -109,9 +111,7 @@ class AnchorGenerator(nn.Module):
# For every (base anchor, output anchor) pair, # For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor. # offset each zero-centered base anchor by the center of the output anchor.
anchors.append( anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
return anchors return anchors
...@@ -119,8 +119,13 @@ class AnchorGenerator(nn.Module): ...@@ -119,8 +119,13 @@ class AnchorGenerator(nn.Module):
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps] grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:] image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), strides = [
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] [
torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device),
]
for g in grid_sizes
]
self.set_cell_anchors(dtype, device) self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides) anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
anchors: List[List[torch.Tensor]] = [] anchors: List[List[torch.Tensor]] = []
...@@ -149,8 +154,15 @@ class DefaultBoxGenerator(nn.Module): ...@@ -149,8 +154,15 @@ class DefaultBoxGenerator(nn.Module):
is applied while the boxes are encoded in format ``(cx, cy, w, h)``. is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
""" """
def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9, def __init__(
scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True): self,
aspect_ratios: List[List[int]],
min_ratio: float = 0.15,
max_ratio: float = 0.9,
scales: Optional[List[float]] = None,
steps: Optional[List[int]] = None,
clip: bool = True,
):
super().__init__() super().__init__()
if steps is not None: if steps is not None:
assert len(aspect_ratios) == len(steps) assert len(aspect_ratios) == len(steps)
...@@ -172,8 +184,9 @@ class DefaultBoxGenerator(nn.Module): ...@@ -172,8 +184,9 @@ class DefaultBoxGenerator(nn.Module):
self._wh_pairs = self._generate_wh_pairs(num_outputs) self._wh_pairs = self._generate_wh_pairs(num_outputs)
def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32, def _generate_wh_pairs(
device: torch.device = torch.device("cpu")) -> List[Tensor]: self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
) -> List[Tensor]:
_wh_pairs: List[Tensor] = [] _wh_pairs: List[Tensor] = []
for k in range(num_outputs): for k in range(num_outputs):
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k # Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
...@@ -196,8 +209,9 @@ class DefaultBoxGenerator(nn.Module): ...@@ -196,8 +209,9 @@ class DefaultBoxGenerator(nn.Module):
return [2 + 2 * len(r) for r in self.aspect_ratios] return [2 + 2 * len(r) for r in self.aspect_ratios]
# Default Boxes calculation based on page 6 of SSD paper # Default Boxes calculation based on page 6 of SSD paper
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int], def _grid_default_boxes(
dtype: torch.dtype = torch.float32) -> Tensor: self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
) -> Tensor:
default_boxes = [] default_boxes = []
for k, f_k in enumerate(grid_sizes): for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair # Now add the default boxes for each width-height pair
...@@ -224,12 +238,12 @@ class DefaultBoxGenerator(nn.Module): ...@@ -224,12 +238,12 @@ class DefaultBoxGenerator(nn.Module):
return torch.cat(default_boxes, dim=0) return torch.cat(default_boxes, dim=0)
def __repr__(self) -> str: def __repr__(self) -> str:
s = self.__class__.__name__ + '(' s = self.__class__.__name__ + "("
s += 'aspect_ratios={aspect_ratios}' s += "aspect_ratios={aspect_ratios}"
s += ', clip={clip}' s += ", clip={clip}"
s += ', scales={scales}' s += ", scales={scales}"
s += ', steps={steps}' s += ", steps={steps}"
s += ')' s += ")"
return s.format(**self.__dict__) return s.format(**self.__dict__)
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
...@@ -242,8 +256,13 @@ class DefaultBoxGenerator(nn.Module): ...@@ -242,8 +256,13 @@ class DefaultBoxGenerator(nn.Module):
dboxes = [] dboxes = []
for _ in image_list.image_sizes: for _ in image_list.image_sizes:
dboxes_in_image = default_boxes dboxes_in_image = default_boxes
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:], dboxes_in_image = torch.cat(
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1) [
dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:],
],
-1,
)
dboxes_in_image[:, 0::2] *= image_size[1] dboxes_in_image[:, 0::2] *= image_size[1]
dboxes_in_image[:, 1::2] *= image_size[0] dboxes_in_image[:, 1::2] *= image_size[0]
dboxes.append(dboxes_in_image) dboxes.append(dboxes_in_image)
......
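A short sketch of driving AnchorGenerator by hand, something the detection models normally do internally; the single stride-8 feature map here is an illustrative assumption:

import torch
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.image_list import ImageList

# One feature-map level, one size, three aspect ratios.
anchor_gen = AnchorGenerator(sizes=((32,),), aspect_ratios=((0.5, 1.0, 2.0),))

batch = torch.randn(2, 3, 256, 256)
image_list = ImageList(batch, image_sizes=[(256, 256), (256, 256)])
feature_maps = [torch.randn(2, 8, 32, 32)]  # 256 / 32 -> stride 8

# One anchor tensor per image: 32 * 32 grid cells * 3 ratios = 3072 anchors.
anchors = anchor_gen(image_list, feature_maps)
print(len(anchors), anchors[0].shape)  # 2 torch.Size([3072, 4])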
import warnings import warnings
from torch import nn from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import mobilenet from .. import mobilenet
from .. import resnet from .. import resnet
from .._utils import IntermediateLayerGetter
class BackboneWithFPN(nn.Module): class BackboneWithFPN(nn.Module):
...@@ -26,6 +27,7 @@ class BackboneWithFPN(nn.Module): ...@@ -26,6 +27,7 @@ class BackboneWithFPN(nn.Module):
Attributes: Attributes:
out_channels (int): the number of channels in the FPN out_channels (int): the number of channels in the FPN
""" """
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None): def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super(BackboneWithFPN, self).__init__() super(BackboneWithFPN, self).__init__()
...@@ -52,7 +54,7 @@ def resnet_fpn_backbone( ...@@ -52,7 +54,7 @@ def resnet_fpn_backbone(
norm_layer=misc_nn_ops.FrozenBatchNorm2d, norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3, trainable_layers=3,
returned_layers=None, returned_layers=None,
extra_blocks=None extra_blocks=None,
): ):
""" """
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone. Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
...@@ -89,15 +91,13 @@ def resnet_fpn_backbone( ...@@ -89,15 +91,13 @@ def resnet_fpn_backbone(
a new list of feature maps and their corresponding names. By a new list of feature maps and their corresponding names. By
default a ``LastLevelMaxPool`` is used. default a ``LastLevelMaxPool`` is used.
""" """
backbone = resnet.__dict__[backbone_name]( backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
pretrained=pretrained,
norm_layer=norm_layer)
# select layers that won't be frozen # select layers that won't be frozen
assert 0 <= trainable_layers <= 5 assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
if trainable_layers == 5: if trainable_layers == 5:
layers_to_train.append('bn1') layers_to_train.append("bn1")
for name, parameter in backbone.named_parameters(): for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]): if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False) parameter.requires_grad_(False)
...@@ -108,7 +108,7 @@ def resnet_fpn_backbone( ...@@ -108,7 +108,7 @@ def resnet_fpn_backbone(
if returned_layers is None: if returned_layers is None:
returned_layers = [1, 2, 3, 4] returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5 assert min(returned_layers) > 0 and max(returned_layers) < 5
return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)} return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
in_channels_stage2 = backbone.inplanes // 8 in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers] in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
...@@ -123,7 +123,8 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, ...@@ -123,7 +123,8 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
warnings.warn( warnings.warn(
"Changing trainable_backbone_layers has not effect if " "Changing trainable_backbone_layers has not effect if "
"neither pretrained nor pretrained_backbone have been set to True, " "neither pretrained nor pretrained_backbone have been set to True, "
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)) "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)
)
trainable_backbone_layers = max_value trainable_backbone_layers = max_value
# by default freeze first blocks # by default freeze first blocks
...@@ -140,7 +141,7 @@ def mobilenet_backbone( ...@@ -140,7 +141,7 @@ def mobilenet_backbone(
norm_layer=misc_nn_ops.FrozenBatchNorm2d, norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2, trainable_layers=2,
returned_layers=None, returned_layers=None,
extra_blocks=None extra_blocks=None,
): ):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
...@@ -165,7 +166,7 @@ def mobilenet_backbone( ...@@ -165,7 +166,7 @@ def mobilenet_backbone(
if returned_layers is None: if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1] returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)} return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers] in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks) return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
......
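A quick sketch of building one of these backbones standalone (random weights assumed, so no download); out_channels is the attribute the detection heads rely on:

import torch
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

# ResNet-50 FPN with the default three trainable stages.
backbone = resnet_fpn_backbone("resnet50", pretrained=False, trainable_layers=3)
print(backbone.out_channels)  # 256, shared by every FPN level

features = backbone(torch.randn(1, 3, 224, 224))
# OrderedDict of levels '0'-'3' plus the extra 'pool' level from LastLevelMaxPool.
print([(name, f.shape) for name, f in features.items()])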
from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn
from torchvision.ops import MultiScaleRoIAlign from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..._internally_replaced_utils import load_state_dict_from_url from ..._internally_replaced_utils import load_state_dict_from_url
from ._utils import overwrite_eps
from .anchor_utils import AnchorGenerator from .anchor_utils import AnchorGenerator
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
from .generalized_rcnn import GeneralizedRCNN from .generalized_rcnn import GeneralizedRCNN
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads from .roi_heads import RoIHeads
from .rpn import RPNHead, RegionProposalNetwork
from .transform import GeneralizedRCNNTransform from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
__all__ = [ __all__ = [
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn", "FasterRCNN",
"fasterrcnn_mobilenet_v3_large_fpn" "fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
] ]
...@@ -141,30 +141,48 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -141,30 +141,48 @@ class FasterRCNN(GeneralizedRCNN):
>>> predictions = model(x) >>> predictions = model(x)
""" """
def __init__(self, backbone, num_classes=None, def __init__(
# transform parameters self,
min_size=800, max_size=1333, backbone,
image_mean=None, image_std=None, num_classes=None,
# RPN parameters # transform parameters
rpn_anchor_generator=None, rpn_head=None, min_size=800,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, max_size=1333,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, image_mean=None,
rpn_nms_thresh=0.7, image_std=None,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, # RPN parameters
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, rpn_anchor_generator=None,
rpn_score_thresh=0.0, rpn_head=None,
# Box parameters rpn_pre_nms_top_n_train=2000,
box_roi_pool=None, box_head=None, box_predictor=None, rpn_pre_nms_top_n_test=1000,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, rpn_post_nms_top_n_train=2000,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, rpn_post_nms_top_n_test=1000,
box_batch_size_per_image=512, box_positive_fraction=0.25, rpn_nms_thresh=0.7,
bbox_reg_weights=None): rpn_fg_iou_thresh=0.7,
rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256,
rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None,
box_head=None,
box_predictor=None,
box_score_thresh=0.05,
box_nms_thresh=0.5,
box_detections_per_img=100,
box_fg_iou_thresh=0.5,
box_bg_iou_thresh=0.5,
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
):
if not hasattr(backbone, "out_channels"): if not hasattr(backbone, "out_channels"):
raise ValueError( raise ValueError(
"backbone should contain an attribute out_channels " "backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the " "specifying the number of output channels (assumed to be the "
"same for all the levels)") "same for all the levels)"
)
assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
...@@ -174,58 +192,59 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -174,58 +192,59 @@ class FasterRCNN(GeneralizedRCNN):
raise ValueError("num_classes should be None when box_predictor is specified") raise ValueError("num_classes should be None when box_predictor is specified")
else: else:
if box_predictor is None: if box_predictor is None:
raise ValueError("num_classes should not be None when box_predictor " raise ValueError("num_classes should not be None when box_predictor " "is not specified")
"is not specified")
out_channels = backbone.out_channels out_channels = backbone.out_channels
if rpn_anchor_generator is None: if rpn_anchor_generator is None:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator( rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
anchor_sizes, aspect_ratios
)
if rpn_head is None: if rpn_head is None:
rpn_head = RPNHead( rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
)
rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
rpn = RegionProposalNetwork( rpn = RegionProposalNetwork(
rpn_anchor_generator, rpn_head, rpn_anchor_generator,
rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_head,
rpn_batch_size_per_image, rpn_positive_fraction, rpn_fg_iou_thresh,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, rpn_bg_iou_thresh,
score_thresh=rpn_score_thresh) rpn_batch_size_per_image,
rpn_positive_fraction,
rpn_pre_nms_top_n,
rpn_post_nms_top_n,
rpn_nms_thresh,
score_thresh=rpn_score_thresh,
)
if box_roi_pool is None: if box_roi_pool is None:
box_roi_pool = MultiScaleRoIAlign( box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
if box_head is None: if box_head is None:
resolution = box_roi_pool.output_size[0] resolution = box_roi_pool.output_size[0]
representation_size = 1024 representation_size = 1024
box_head = TwoMLPHead( box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
out_channels * resolution ** 2,
representation_size)
if box_predictor is None: if box_predictor is None:
representation_size = 1024 representation_size = 1024
box_predictor = FastRCNNPredictor( box_predictor = FastRCNNPredictor(representation_size, num_classes)
representation_size,
num_classes)
roi_heads = RoIHeads( roi_heads = RoIHeads(
# Box # Box
box_roi_pool, box_head, box_predictor, box_roi_pool,
box_fg_iou_thresh, box_bg_iou_thresh, box_head,
box_batch_size_per_image, box_positive_fraction, box_predictor,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights, bbox_reg_weights,
box_score_thresh, box_nms_thresh, box_detections_per_img) box_score_thresh,
box_nms_thresh,
box_detections_per_img,
)
if image_mean is None: if image_mean is None:
image_mean = [0.485, 0.456, 0.406] image_mean = [0.485, 0.456, 0.406]
...@@ -286,17 +305,15 @@ class FastRCNNPredictor(nn.Module): ...@@ -286,17 +305,15 @@ class FastRCNNPredictor(nn.Module):
model_urls = { model_urls = {
'fasterrcnn_resnet50_fpn_coco': "fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', "fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
'fasterrcnn_mobilenet_v3_large_320_fpn_coco': "fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
'fasterrcnn_mobilenet_v3_large_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth'
} }
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, def fasterrcnn_resnet50_fpn(
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
""" """
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone. Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
...@@ -362,36 +379,54 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, ...@@ -362,36 +379,54 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
""" """
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained: if pretrained:
# no need to download the backbone if pretrained is set # no need to download the backbone if pretrained is set
pretrained_backbone = False pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs) model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
progress=progress)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91, def _fasterrcnn_mobilenet_v3_large_fpn(
pretrained_backbone=True, trainable_backbone_layers=None, **kwargs): weights_name,
pretrained=False,
progress=True,
num_classes=91,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers( trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
)
if pretrained: if pretrained:
pretrained_backbone = False pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True, backbone = mobilenet_backbone(
trainable_layers=trainable_backbone_layers) "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers
)
anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
anchor_sizes = (
(
32,
64,
128,
256,
512,
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), model = FasterRCNN(
**kwargs) backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if pretrained: if pretrained:
if model_urls.get(weights_name, None) is None: if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name)) raise ValueError("No checkpoint is available for model {}".format(weights_name))
...@@ -400,8 +435,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress= ...@@ -400,8 +435,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=
return model return model
def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, def fasterrcnn_mobilenet_v3_large_320_fpn(
trainable_backbone_layers=None, **kwargs): pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
""" """
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases. Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use-cases.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
...@@ -433,13 +469,20 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c ...@@ -433,13 +469,20 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c
} }
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, return _fasterrcnn_mobilenet_v3_large_fpn(
num_classes=num_classes, pretrained_backbone=pretrained_backbone, weights_name,
trainable_backbone_layers=trainable_backbone_layers, **kwargs) pretrained=pretrained,
progress=progress,
num_classes=num_classes,
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=None, **kwargs): trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
def fasterrcnn_mobilenet_v3_large_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
""" """
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone. Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
...@@ -467,6 +510,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class ...@@ -467,6 +510,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
} }
kwargs = {**defaults, **kwargs} kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress, return _fasterrcnn_mobilenet_v3_large_fpn(
num_classes=num_classes, pretrained_backbone=pretrained_backbone, weights_name,
trainable_backbone_layers=trainable_backbone_layers, **kwargs) pretrained=pretrained,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
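A minimal inference sketch for the builders above, with pretrained weights disabled so nothing is fetched; in eval mode the model consumes a list of CHW tensors and returns one dict per image:

import torch
from torchvision.models.detection import fasterrcnn_mobilenet_v3_large_fpn

model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
model.eval()

images = [torch.randn(3, 320, 480)]  # image sizes may vary within the list
with torch.no_grad():
    predictions = model(images)
print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])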
...@@ -2,11 +2,12 @@ ...@@ -2,11 +2,12 @@
Implements the Generalized R-CNN framework Implements the Generalized R-CNN framework
""" """
import warnings
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple, List, Dict, Optional, Union
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
import warnings
from typing import Tuple, List, Dict, Optional, Union
class GeneralizedRCNN(nn.Module): class GeneralizedRCNN(nn.Module):
...@@ -61,12 +62,11 @@ class GeneralizedRCNN(nn.Module): ...@@ -61,12 +62,11 @@ class GeneralizedRCNN(nn.Module):
boxes = target["boxes"] boxes = target["boxes"]
if isinstance(boxes, torch.Tensor): if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4: if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor" raise ValueError(
"of shape [N, 4], got {:}.".format( "Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape)
boxes.shape)) )
else: else:
raise ValueError("Expected target boxes to be of type " raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))
"Tensor, got {:}.".format(type(boxes)))
original_image_sizes: List[Tuple[int, int]] = [] original_image_sizes: List[Tuple[int, int]] = []
for img in images: for img in images:
...@@ -86,13 +86,14 @@ class GeneralizedRCNN(nn.Module): ...@@ -86,13 +86,14 @@ class GeneralizedRCNN(nn.Module):
# print the first degenerate box # print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist() degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width." raise ValueError(
" Found invalid box {} for target at index {}." "All bounding boxes should have positive height and width."
.format(degen_bb, target_idx)) " Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
)
features = self.backbone(images.tensors) features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor): if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)]) features = OrderedDict([("0", features)])
proposals, proposal_losses = self.rpn(images, features, targets) proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
......
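The validation above pins down the training contract: each target needs an [N, 4] boxes tensor with positive widths and heights. A minimal training-mode sketch under that contract (random weights assumed):

import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn

model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False, num_classes=2)
model.train()

images = [torch.randn(3, 300, 400)]
targets = [{
    "boxes": torch.tensor([[10.0, 20.0, 100.0, 150.0]]),  # x1, y1, x2, y2
    "labels": torch.tensor([1]),
}]
loss_dict = model(images, targets)  # dict of scalar losses in training mode
print(sorted(loss_dict))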
from typing import List, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from typing import List, Tuple
class ImageList(object): class ImageList(object):
...@@ -20,6 +21,6 @@ class ImageList(object): ...@@ -20,6 +21,6 @@ class ImageList(object):
self.tensors = tensors self.tensors = tensors
self.image_sizes = image_sizes self.image_sizes = image_sizes
def to(self, device: torch.device) -> 'ImageList': def to(self, device: torch.device) -> "ImageList":
cast_tensor = self.tensors.to(device) cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes) return ImageList(cast_tensor, self.image_sizes)
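And a sketch of the ImageList contract itself: tensors holds the padded batch, image_sizes the pre-padding sizes, and to() returns a new ImageList rather than mutating in place:

import torch
from torchvision.models.detection.image_list import ImageList

batched = torch.zeros(2, 3, 320, 320)  # padded batch
image_list = ImageList(batched, image_sizes=[(300, 320), (320, 280)])

moved = image_list.to(torch.device("cpu"))  # new ImageList on the target device
print(moved.tensors.shape, moved.image_sizes)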