Unverified Commit 5f0edb97 authored by Philip Meier, committed by GitHub

Add ufmt (usort + black) as code formatter (#4384)



* add ufmt as code formatter

* cleanup

* quote ufmt requirement

* split imports into more groups

* regenerate circleci config

* fix CI

* clarify local testing utils section

* use ufmt pre-commit hook

* split relative imports into local category

* Revert "split relative imports into local category"

This reverts commit f2e224cde2008c56c9347c1f69746d39065cdd51.

* pin black and usort dependencies

* fix local test utils detection

* fix ufmt rev

* add reference utils to local category

* fix usort config

* remove custom categories sorting

* Run pre-commit without fixing flake8

* got a double import in merge
Co-authored-by: Nicolas Hug <nicolashug@fb.com>
parent e45489b1
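
The changes that follow are mechanical: black normalizes string quotes and line wrapping, and usort regroups module-level imports into blocks (standard library, then third-party, then first-party/local), alphabetized within each block and separated by a blank line. A minimal sketch of that import layout, assuming default usort grouping; the module names below are illustrative and not taken from this commit:

    # usort/ufmt-style grouping: standard library first, alphabetized
    import gzip
    import hashlib
    import os

    # third-party packages form their own block
    import torch

    # first-party / local imports (relative to the repo) come last
    from torchvision.datasets.utils import check_integrity
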
import bz2
import gzip
import hashlib
import itertools
import lzma
import os
import os.path
import hashlib
import gzip
import pathlib
import re
import tarfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
import zipfile
import lzma
import urllib
import urllib.request
import urllib.error
import pathlib
import itertools
import urllib.request
import zipfile
from typing import Any, Callable, List, Iterable, Optional, TypeVar, Dict, IO, Tuple, Iterator
from urllib.parse import urlparse
import torch
from torch.utils.model_zoo import tqdm
......@@ -52,8 +52,8 @@ def gen_bar_updater() -> Callable[[int, int, int], None]:
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
md5 = hashlib.md5()
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
return md5.hexdigest()
......@@ -120,7 +120,7 @@ def download_url(
# check if file is already present locally
if check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
print("Using downloaded and verified file: " + fpath)
return
if _is_remote_location_available():
......@@ -136,13 +136,12 @@ def download_url(
# download the file
try:
print('Downloading ' + url + ' to ' + fpath)
print("Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
if url[:5] == "https":
url = url.replace("https:", "http:")
print("Failed download. Trying https -> http instead." " Downloading " + url + " to " + fpath)
_urlretrieve(url, fpath)
else:
raise e
......@@ -202,6 +201,7 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"
root = os.path.expanduser(root)
......@@ -212,15 +212,15 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
os.makedirs(root, exist_ok=True)
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
print("Using downloaded and verified file: " + fpath)
else:
session = requests.Session()
response = session.get(url, params={'id': file_id}, stream=True)
response = session.get(url, params={"id": file_id}, stream=True)
token = _get_confirm_token(response)
if token:
params = {'id': file_id, 'confirm': token}
params = {"id": file_id, "confirm": token}
response = session.get(url, params=params, stream=True)
# Ideally, one would use response.status_code to check for quota limits, but google drive is not consistent
......@@ -240,20 +240,21 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
)
raise RuntimeError(msg)
_save_response_content(itertools.chain((first_chunk, ), response_content_generator), fpath)
_save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath)
response.close()
def _get_confirm_token(response: "requests.models.Response") -> Optional[str]: # type: ignore[name-defined]
for key, value in response.cookies.items():
if key.startswith('download_warning'):
if key.startswith("download_warning"):
return value
return None
def _save_response_content(
response_gen: Iterator[bytes], destination: str, # type: ignore[name-defined]
response_gen: Iterator[bytes],
destination: str, # type: ignore[name-defined]
) -> None:
with open(destination, "wb") as f:
pbar = tqdm(total=None)
......@@ -439,7 +440,10 @@ T = TypeVar("T", str, bytes)
def verify_str_arg(
value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None,
value: T,
arg: Optional[str] = None,
valid_values: Iterable[T] = None,
custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, torch._six.string_classes):
if arg is None:
......@@ -456,10 +460,8 @@ def verify_str_arg(
if custom_msg is not None:
msg = custom_msg
else:
msg = ("Unknown value '{value}' for argument {arg}. "
"Valid values are {{{valid_values}}}.")
msg = msg.format(value=value, arg=arg,
valid_values=iterable_to_str(valid_values))
msg = "Unknown value '{value}' for argument {arg}. " "Valid values are {{{valid_values}}}."
msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values))
raise ValueError(msg)
return value
......@@ -206,14 +206,14 @@ class VideoClips(object):
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
idxs = VideoClips._resample_video_idx(
int(math.floor(total_frames)), fps, frame_rate
)
idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
video_pts = video_pts[idxs]
clips = unfold(video_pts, num_frames, step)
if not clips.numel():
warnings.warn("There aren't enough frames in the current video to get a clip for the given clip length and "
"frames between clips. The video (and potentially others) will be skipped.")
warnings.warn(
"There aren't enough frames in the current video to get a clip for the given clip length and "
"frames between clips. The video (and potentially others) will be skipped."
)
if isinstance(idxs, slice):
idxs = [idxs] * len(clips)
else:
......@@ -237,9 +237,7 @@ class VideoClips(object):
self.clips = []
self.resampling_idxs = []
for video_pts, fps in zip(self.video_pts, self.video_fps):
clips, idxs = self.compute_clips_for_video(
video_pts, num_frames, step, fps, frame_rate
)
clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate)
self.clips.append(clips)
self.resampling_idxs.append(idxs)
clip_lengths = torch.as_tensor([len(v) for v in self.clips])
......@@ -295,10 +293,7 @@ class VideoClips(object):
video_idx (int): index of the video in `video_paths`
"""
if idx >= self.num_clips():
raise IndexError(
"Index {} out of range "
"({} number of clips)".format(idx, self.num_clips())
)
raise IndexError("Index {} out of range " "({} number of clips)".format(idx, self.num_clips()))
video_idx, clip_idx = self.get_clip_location(idx)
video_path = self.video_paths[video_idx]
clip_pts = self.clips[video_idx][clip_idx]
......@@ -314,13 +309,9 @@ class VideoClips(object):
if self._video_height != 0:
raise ValueError("pyav backend doesn't support _video_height != 0")
if self._video_min_dimension != 0:
raise ValueError(
"pyav backend doesn't support _video_min_dimension != 0"
)
raise ValueError("pyav backend doesn't support _video_min_dimension != 0")
if self._video_max_dimension != 0:
raise ValueError(
"pyav backend doesn't support _video_max_dimension != 0"
)
raise ValueError("pyav backend doesn't support _video_max_dimension != 0")
if self._audio_samples != 0:
raise ValueError("pyav backend doesn't support _audio_samples != 0")
......@@ -338,19 +329,11 @@ class VideoClips(object):
audio_start_pts, audio_end_pts = 0, -1
audio_timebase = Fraction(0, 1)
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
if info.has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_start_pts = pts_convert(
video_start_pts, video_timebase, audio_timebase, math.floor
)
audio_end_pts = pts_convert(
video_end_pts, video_timebase, audio_timebase, math.ceil
)
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
audio_start_pts = pts_convert(video_start_pts, video_timebase, audio_timebase, math.floor)
audio_end_pts = pts_convert(video_end_pts, video_timebase, audio_timebase, math.ceil)
audio_fps = info.audio_sample_rate
video, audio, info = _read_video_from_file(
video_path,
......@@ -376,9 +359,7 @@ class VideoClips(object):
resampling_idx = resampling_idx - resampling_idx[0]
video = video[resampling_idx]
info["video_fps"] = self.frame_rate
assert len(video) == self.num_frames, "{} x {}".format(
video.shape, self.num_frames
)
assert len(video) == self.num_frames, "{} x {}".format(video.shape, self.num_frames)
return video, audio, info, video_idx
def __getstate__(self):
......
import os
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.utils.data as data
from typing import Any, Callable, List, Optional, Tuple
class VisionDataset(data.Dataset):
......@@ -22,14 +23,15 @@ class VisionDataset(data.Dataset):
:attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive.
"""
_repr_indent = 4
def __init__(
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
self,
root: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
torch._C._log_api_usage_once(f"torchvision.datasets.{self.__class__.__name__}")
if isinstance(root, torch._six.string_classes):
......@@ -39,8 +41,7 @@ class VisionDataset(data.Dataset):
has_transforms = transforms is not None
has_separate_transform = transform is not None or target_transform is not None
if has_transforms and has_separate_transform:
raise ValueError("Only transforms or transform/target_transform can "
"be passed as argument")
raise ValueError("Only transforms or transform/target_transform can " "be passed as argument")
# for backwards-compatibility
self.transform = transform
......@@ -72,12 +73,11 @@ class VisionDataset(data.Dataset):
if hasattr(self, "transforms") and self.transforms is not None:
body += [repr(self.transforms)]
lines = [head] + [" " * self._repr_indent + line for line in body]
return '\n'.join(lines)
return "\n".join(lines)
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def extra_repr(self) -> str:
return ""
......@@ -97,16 +97,13 @@ class StandardTransform(object):
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
lines = transform.__repr__().splitlines()
return (["{}{}".format(head, lines[0])] +
["{}{}".format(" " * len(head), line) for line in lines[1:]])
return ["{}{}".format(head, lines[0])] + ["{}{}".format(" " * len(head), line) for line in lines[1:]]
def __repr__(self) -> str:
body = [self.__class__.__name__]
if self.transform is not None:
body += self._format_transform_repr(self.transform,
"Transform: ")
body += self._format_transform_repr(self.transform, "Transform: ")
if self.target_transform is not None:
body += self._format_transform_repr(self.target_transform,
"Target transform: ")
body += self._format_transform_repr(self.target_transform, "Target transform: ")
return '\n'.join(body)
return "\n".join(body)
import os
import collections
from .vision import VisionDataset
import os
from xml.etree.ElementTree import Element as ET_Element
from .vision import VisionDataset
try:
from defusedxml.ElementTree import parse as ET_parse
except ImportError:
from xml.etree.ElementTree import parse as ET_parse
from PIL import Image
import warnings
from typing import Any, Callable, Dict, Optional, Tuple, List
from PIL import Image
from .utils import download_and_extract_archive, verify_str_arg
import warnings
DATASET_YEAR_DICT = {
'2012': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '6cd6e144f989b92b3379bac3b3de84fd',
'base_dir': os.path.join('VOCdevkit', 'VOC2012')
"2012": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
"filename": "VOCtrainval_11-May-2012.tar",
"md5": "6cd6e144f989b92b3379bac3b3de84fd",
"base_dir": os.path.join("VOCdevkit", "VOC2012"),
},
"2011": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar",
"filename": "VOCtrainval_25-May-2011.tar",
"md5": "6c3384ef61512963050cb5d687e5bf1e",
"base_dir": os.path.join("TrainVal", "VOCdevkit", "VOC2011"),
},
'2011': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
'filename': 'VOCtrainval_25-May-2011.tar',
'md5': '6c3384ef61512963050cb5d687e5bf1e',
'base_dir': os.path.join('TrainVal', 'VOCdevkit', 'VOC2011')
"2010": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar",
"filename": "VOCtrainval_03-May-2010.tar",
"md5": "da459979d0c395079b5c75ee67908abb",
"base_dir": os.path.join("VOCdevkit", "VOC2010"),
},
'2010': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
'filename': 'VOCtrainval_03-May-2010.tar',
'md5': 'da459979d0c395079b5c75ee67908abb',
'base_dir': os.path.join('VOCdevkit', 'VOC2010')
"2009": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar",
"filename": "VOCtrainval_11-May-2009.tar",
"md5": "59065e4b188729180974ef6572f6a212",
"base_dir": os.path.join("VOCdevkit", "VOC2009"),
},
'2009': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
'filename': 'VOCtrainval_11-May-2009.tar',
'md5': '59065e4b188729180974ef6572f6a212',
'base_dir': os.path.join('VOCdevkit', 'VOC2009')
"2008": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar",
"filename": "VOCtrainval_11-May-2012.tar",
"md5": "2629fa636546599198acfcfbfcf1904a",
"base_dir": os.path.join("VOCdevkit", "VOC2008"),
},
'2008': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '2629fa636546599198acfcfbfcf1904a',
'base_dir': os.path.join('VOCdevkit', 'VOC2008')
"2007": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar",
"filename": "VOCtrainval_06-Nov-2007.tar",
"md5": "c52e279531787c972589f7e41ab4ae64",
"base_dir": os.path.join("VOCdevkit", "VOC2007"),
},
'2007': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
'filename': 'VOCtrainval_06-Nov-2007.tar',
'md5': 'c52e279531787c972589f7e41ab4ae64',
'base_dir': os.path.join('VOCdevkit', 'VOC2007')
"2007-test": {
"url": "http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar",
"filename": "VOCtest_06-Nov-2007.tar",
"md5": "b6e924de25625d8de591ea690078ad9f",
"base_dir": os.path.join("VOCdevkit", "VOC2007"),
},
'2007-test': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
'filename': 'VOCtest_06-Nov-2007.tar',
'md5': 'b6e924de25625d8de591ea690078ad9f',
'base_dir': os.path.join('VOCdevkit', 'VOC2007')
}
}
......
from PIL import Image
import os
from os.path import abspath, expanduser
import torch
from typing import Any, Callable, List, Dict, Optional, Tuple, Union
from .utils import check_integrity, download_file_from_google_drive, \
download_and_extract_archive, extract_archive, verify_str_arg
import torch
from PIL import Image
from .utils import (
check_integrity,
download_file_from_google_drive,
download_and_extract_archive,
extract_archive,
verify_str_arg,
)
from .vision import VisionDataset
......@@ -40,25 +47,25 @@ class WIDERFace(VisionDataset):
# File ID MD5 Hash Filename
("0B6eKvaijfFUDQUUwd21EckhUbWs", "3fedf70df600953d25982bcd13d91ba2", "WIDER_train.zip"),
("0B6eKvaijfFUDd3dIRmpvSk8tLUk", "dfa7d7e790efa35df3788964cf0bbaea", "WIDER_val.zip"),
("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip")
("0B6eKvaijfFUDbW4tdGpaYjgzZkU", "e5d8f4248ed24c334bbd12f49c29dd40", "WIDER_test.zip"),
]
ANNOTATIONS_FILE = (
"http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip",
"0e3767bcf0e326556d407bf5bff5d27c",
"wider_face_split.zip"
"wider_face_split.zip",
)
def __init__(
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
self,
root: str,
split: str = "train",
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(WIDERFace, self).__init__(root=os.path.join(root, self.BASE_FOLDER),
transform=transform,
target_transform=target_transform)
super(WIDERFace, self).__init__(
root=os.path.join(root, self.BASE_FOLDER), transform=transform, target_transform=target_transform
)
# check arguments
self.split = verify_str_arg(split, "split", ("train", "val", "test"))
......@@ -66,8 +73,9 @@ class WIDERFace(VisionDataset):
self.download()
if not self._check_integrity():
raise RuntimeError("Dataset not found or corrupted. " +
"You can use download=True to download and prepare it")
raise RuntimeError(
"Dataset not found or corrupted. " + "You can use download=True to download and prepare it"
)
self.img_info: List[Dict[str, Union[str, Dict[str, torch.Tensor]]]] = []
if self.split in ("train", "val"):
......@@ -102,7 +110,7 @@ class WIDERFace(VisionDataset):
def extra_repr(self) -> str:
lines = ["Split: {split}"]
return '\n'.join(lines).format(**self.__dict__)
return "\n".join(lines).format(**self.__dict__)
def parse_train_val_annotations_file(self) -> None:
filename = "wider_face_train_bbx_gt.txt" if self.split == "train" else "wider_face_val_bbx_gt.txt"
......@@ -133,16 +141,20 @@ class WIDERFace(VisionDataset):
box_annotation_line = False
file_name_line = True
labels_tensor = torch.tensor(labels)
self.img_info.append({
"img_path": img_path,
"annotations": {"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9]}
})
self.img_info.append(
{
"img_path": img_path,
"annotations": {
"bbox": labels_tensor[:, 0:4], # x, y, width, height
"blur": labels_tensor[:, 4],
"expression": labels_tensor[:, 5],
"illumination": labels_tensor[:, 6],
"occlusion": labels_tensor[:, 7],
"pose": labels_tensor[:, 8],
"invalid": labels_tensor[:, 9],
},
}
)
box_counter = 0
labels.clear()
else:
......@@ -172,7 +184,7 @@ class WIDERFace(VisionDataset):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
print("Files already downloaded and verified")
return
# download and extract image data
......@@ -182,6 +194,6 @@ class WIDERFace(VisionDataset):
extract_archive(filepath)
# download and extract annotation files
download_and_extract_archive(url=self.ANNOTATIONS_FILE[0],
download_root=self.root,
md5=self.ANNOTATIONS_FILE[1])
download_and_extract_archive(
url=self.ANNOTATIONS_FILE[0], download_root=self.root, md5=self.ANNOTATIONS_FILE[1]
)
......@@ -11,12 +11,14 @@ def _has_ops():
try:
lib_path = _get_extension_path('_C')
lib_path = _get_extension_path("_C")
torch.ops.load_library(lib_path)
_HAS_OPS = True
def _has_ops(): # noqa: F811
return True
except (ImportError, OSError):
pass
......@@ -41,6 +43,7 @@ def _check_cuda_version():
if not _HAS_OPS:
return -1
import torch
_version = torch.ops.torchvision._cuda_version()
if _version != -1 and torch.version.cuda is not None:
tv_version = str(_version)
......@@ -51,14 +54,17 @@ def _check_cuda_version():
tv_major = int(tv_version[0:2])
tv_minor = int(tv_version[3])
t_version = torch.version.cuda
t_version = t_version.split('.')
t_version = t_version.split(".")
t_major = int(t_version[0])
t_minor = int(t_version[1])
if t_major != tv_major or t_minor != tv_minor:
raise RuntimeError("Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install."
.format(t_major, t_minor, tv_major, tv_minor))
raise RuntimeError(
"Detected that PyTorch and torchvision were compiled with different CUDA versions. "
"PyTorch has CUDA Version={}.{} and torchvision has CUDA Version={}.{}. "
"Please reinstall the torchvision that matches your PyTorch install.".format(
t_major, t_minor, tv_major, tv_minor
)
)
return _version
......
import torch
from typing import Any, Dict, Iterator
import torch
from ._video_opt import (
Timebase,
VideoMetaData,
......@@ -12,11 +13,6 @@ from ._video_opt import (
_read_video_timestamps_from_file,
_read_video_timestamps_from_memory,
)
from .video import (
read_video,
read_video_timestamps,
write_video,
)
from .image import (
ImageReadMode,
decode_image,
......@@ -30,6 +26,11 @@ from .image import (
write_jpeg,
write_png,
)
from .video import (
read_video,
read_video_timestamps,
write_video,
)
if _HAS_VIDEO_OPT:
......@@ -127,10 +128,10 @@ class VideoReader:
raise StopIteration
return {"data": frame, "pts": pts}
def __iter__(self) -> Iterator['VideoReader']:
def __iter__(self) -> Iterator["VideoReader"]:
return self
def seek(self, time_s: float) -> 'VideoReader':
def seek(self, time_s: float) -> "VideoReader":
"""Seek within current stream.
Args:
......
import math
import os
import warnings
......@@ -12,7 +11,7 @@ from .._internally_replaced_utils import _get_extension_path
try:
lib_path = _get_extension_path('video_reader')
lib_path = _get_extension_path("video_reader")
torch.ops.load_library(lib_path)
_HAS_VIDEO_OPT = True
except (ImportError, OSError):
......@@ -90,9 +89,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
"""
meta = VideoMetaData()
if vtimebase.numel() > 0:
meta.video_timebase = Timebase(
int(vtimebase[0].item()), int(vtimebase[1].item())
)
meta.video_timebase = Timebase(int(vtimebase[0].item()), int(vtimebase[1].item()))
timebase = vtimebase[0].item() / float(vtimebase[1].item())
if vduration.numel() > 0:
meta.has_video = True
......@@ -100,9 +97,7 @@ def _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration):
if vfps.numel() > 0:
meta.video_fps = float(vfps.item())
if atimebase.numel() > 0:
meta.audio_timebase = Timebase(
int(atimebase[0].item()), int(atimebase[1].item())
)
meta.audio_timebase = Timebase(int(atimebase[0].item()), int(atimebase[1].item()))
timebase = atimebase[0].item() / float(atimebase[1].item())
if aduration.numel() > 0:
meta.has_audio = True
......@@ -216,10 +211,7 @@ def _read_video_from_file(
audio_timebase.numerator,
audio_timebase.denominator,
)
vframes, _vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
if aframes.numel() > 0:
# when audio stream is found
......@@ -254,8 +246,7 @@ def _read_video_timestamps_from_file(filename):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, vduration, \
_aframes, aframe_pts, atimebase, asample_rate, aduration = result
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
......@@ -372,10 +363,7 @@ def _read_video_from_memory(
audio_timebase_denominator,
)
vframes, _vframe_pts, vtimebase, vfps, vduration, \
aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
vframes, _vframe_pts, vtimebase, vfps, vduration, aframes, aframe_pts, atimebase, asample_rate, aduration = result
if aframes.numel() > 0:
# when audio stream is found
......@@ -413,10 +401,7 @@ def _read_video_timestamps_from_memory(video_data):
0, # audio_timebase_num
1, # audio_timebase_den
)
_vframes, vframe_pts, vtimebase, vfps, vduration, \
_aframes, aframe_pts, atimebase, asample_rate, aduration = (
result
)
_vframes, vframe_pts, vtimebase, vfps, vduration, _aframes, aframe_pts, atimebase, asample_rate, aduration = result
info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
vframe_pts = vframe_pts.numpy().tolist()
......@@ -439,10 +424,10 @@ def _probe_video_from_memory(video_data):
def _convert_to_sec(start_pts, end_pts, pts_unit, time_base):
if pts_unit == 'pts':
if pts_unit == "pts":
start_pts = float(start_pts * time_base)
end_pts = float(end_pts * time_base)
pts_unit = 'sec'
pts_unit = "sec"
return start_pts, end_pts, pts_unit
......@@ -467,20 +452,15 @@ def _read_video(filename, start_pts=0, end_pts=None, pts_unit="pts"):
time_base = default_timebase
if has_video:
video_timebase = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_timebase = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
time_base = video_timebase
if has_audio:
audio_timebase = Fraction(
info.audio_timebase.numerator, info.audio_timebase.denominator
)
audio_timebase = Fraction(info.audio_timebase.numerator, info.audio_timebase.denominator)
time_base = time_base if time_base else audio_timebase
# video_timebase is the default time_base
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
start_pts_sec, end_pts_sec, pts_unit = _convert_to_sec(start_pts, end_pts, pts_unit, time_base)
def get_pts(time_base):
start_offset = start_pts_sec
......@@ -527,9 +507,7 @@ def _read_video_timestamps(filename, pts_unit="pts"):
pts, _, info = _read_video_timestamps_from_file(filename)
if pts_unit == "sec":
video_time_base = Fraction(
info.video_timebase.numerator, info.video_timebase.denominator
)
video_time_base = Fraction(info.video_timebase.numerator, info.video_timebase.denominator)
pts = [x * video_time_base for x in pts]
video_fps = info.video_fps if info.has_video else None
......
import torch
from enum import Enum
import torch
from .._internally_replaced_utils import _get_extension_path
try:
lib_path = _get_extension_path('image')
lib_path = _get_extension_path("image")
torch.ops.load_library(lib_path)
except (ImportError, OSError):
pass
......@@ -21,6 +22,7 @@ class ImageReadMode(Enum):
``ImageReadMode.RGB`` for RGB and ``ImageReadMode.RGB_ALPHA`` for
RGB with transparency.
"""
UNCHANGED = 0
GRAY = 1
GRAY_ALPHA = 2
......@@ -111,8 +113,9 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output)
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
device: str = 'cpu') -> torch.Tensor:
def decode_jpeg(
input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
......@@ -135,7 +138,7 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
output (Tensor[image_channels, image_height, image_width])
"""
device = torch.device(device)
if device.type == 'cuda':
if device.type == "cuda":
output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
else:
output = torch.ops.image.decode_jpeg(input, mode.value)
......@@ -158,8 +161,7 @@ def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
JPEG file.
"""
if quality < 1 or quality > 100:
raise ValueError('Image quality should be a positive number '
'between 1 and 100')
raise ValueError("Image quality should be a positive number " "between 1 and 100")
output = torch.ops.image.encode_jpeg(input, quality)
return output
......
......@@ -94,16 +94,16 @@ def write_video(
if audio_array is not None:
audio_format_dtypes = {
'dbl': '<f8',
'dblp': '<f8',
'flt': '<f4',
'fltp': '<f4',
's16': '<i2',
's16p': '<i2',
's32': '<i4',
's32p': '<i4',
'u8': 'u1',
'u8p': 'u1',
"dbl": "<f8",
"dblp": "<f8",
"flt": "<f4",
"fltp": "<f4",
"s16": "<i2",
"s16p": "<i2",
"s32": "<i4",
"s32p": "<i4",
"u8": "u1",
"u8p": "u1",
}
a_stream = container.add_stream(audio_codec, rate=audio_fps)
a_stream.options = audio_options or {}
......@@ -115,9 +115,7 @@ def write_video(
format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
audio_array = torch.as_tensor(audio_array).numpy().astype(format_dtype)
frame = av.AudioFrame.from_ndarray(
audio_array, format=audio_sample_fmt, layout=audio_layout
)
frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)
frame.sample_rate = audio_fps
......@@ -207,9 +205,7 @@ def _read_from_stream(
# TODO add a warning
pass
# ensure that the results are sorted wrt the pts
result = [
frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset
]
result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
# if there is no frame that exactly matches the pts of start_offset
# add the last frame smaller than start_offset, to guarantee that
......@@ -264,7 +260,7 @@ def read_video(
from torchvision import get_video_backend
if not os.path.exists(filename):
raise RuntimeError(f'File not found: {filename}')
raise RuntimeError(f"File not found: {filename}")
if get_video_backend() != "pyav":
return _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
......@@ -276,8 +272,7 @@ def read_video(
if end_pts < start_pts:
raise ValueError(
"end_pts should be larger than start_pts, got "
"start_pts={} and end_pts={}".format(start_pts, end_pts)
"end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)
)
info = {}
......@@ -295,8 +290,7 @@ def read_video(
elif container.streams.audio:
time_base = container.streams.audio[0].time_base
# video_timebase is the default time_base
start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(
start_pts, end_pts, pts_unit, time_base)
start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, pts_unit, time_base)
if container.streams.video:
video_frames = _read_from_stream(
container,
......@@ -337,7 +331,7 @@ def read_video(
if aframes_list:
aframes = np.concatenate(aframes_list, 1)
aframes = torch.as_tensor(aframes)
if pts_unit == 'sec':
if pts_unit == "sec":
start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
if end_pts != float("inf"):
end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
......
......@@ -10,8 +10,8 @@ from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *
from .regnet import *
from . import segmentation
from . import detection
from . import video
from . import quantization
from . import feature_extraction
from . import quantization
from . import segmentation
from . import video
from collections import OrderedDict
from typing import Dict, Optional
from torch import nn
from typing import Dict, Optional
class IntermediateLayerGetter(nn.ModuleDict):
......@@ -35,6 +35,7 @@ class IntermediateLayerGetter(nn.ModuleDict):
>>> [('feat1', torch.Size([1, 64, 56, 56])),
>>> ('feat2', torch.Size([1, 256, 14, 14]))]
"""
_version = 2
__annotations__ = {
"return_layers": Dict[str, str],
......
from typing import Any
import torch
import torch.nn as nn
from .._internally_replaced_utils import load_state_dict_from_url
from typing import Any
__all__ = ['AlexNet', 'alexnet']
__all__ = ["AlexNet", "alexnet"]
model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-7be5be79.pth',
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
}
class AlexNet(nn.Module):
def __init__(self, num_classes: int = 1000) -> None:
super(AlexNet, self).__init__()
self.features = nn.Sequential(
......@@ -61,7 +62,6 @@ def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> A
"""
model = AlexNet(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['alexnet'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress)
model.load_state_dict(state_dict)
return model
import re
from collections import OrderedDict
from typing import Any, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict
from .._internally_replaced_utils import load_state_dict_from_url
from torch import Tensor
from typing import Any, List, Tuple
from .._internally_replaced_utils import load_state_dict_from_url
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
__all__ = ["DenseNet", "densenet121", "densenet169", "densenet201", "densenet161"]
model_urls = {
'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
}
class _DenseLayer(nn.Module):
def __init__(
self,
num_input_features: int,
growth_rate: int,
bn_size: int,
drop_rate: float,
memory_efficient: bool = False
self, num_input_features: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False
) -> None:
super(_DenseLayer, self).__init__()
self.norm1: nn.BatchNorm2d
self.add_module('norm1', nn.BatchNorm2d(num_input_features))
self.add_module("norm1", nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU
self.add_module('relu1', nn.ReLU(inplace=True))
self.add_module("relu1", nn.ReLU(inplace=True))
self.conv1: nn.Conv2d
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1,
bias=False))
self.add_module(
"conv1", nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)
)
self.norm2: nn.BatchNorm2d
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
self.relu2: nn.ReLU
self.add_module('relu2', nn.ReLU(inplace=True))
self.add_module("relu2", nn.ReLU(inplace=True))
self.conv2: nn.Conv2d
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1,
bias=False))
self.add_module(
"conv2", nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)
)
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
......@@ -93,8 +90,7 @@ class _DenseLayer(nn.Module):
new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
if self.drop_rate > 0:
new_features = F.dropout(new_features, p=self.drop_rate,
training=self.training)
new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
return new_features
......@@ -108,7 +104,7 @@ class _DenseBlock(nn.ModuleDict):
bn_size: int,
growth_rate: int,
drop_rate: float,
memory_efficient: bool = False
memory_efficient: bool = False,
) -> None:
super(_DenseBlock, self).__init__()
for i in range(num_layers):
......@@ -119,7 +115,7 @@ class _DenseBlock(nn.ModuleDict):
drop_rate=drop_rate,
memory_efficient=memory_efficient,
)
self.add_module('denselayer%d' % (i + 1), layer)
self.add_module("denselayer%d" % (i + 1), layer)
def forward(self, init_features: Tensor) -> Tensor:
features = [init_features]
......@@ -132,11 +128,10 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
kernel_size=1, stride=1, bias=False))
self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
self.add_module("norm", nn.BatchNorm2d(num_input_features))
self.add_module("relu", nn.ReLU(inplace=True))
self.add_module("conv", nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))
class DenseNet(nn.Module):
......@@ -163,19 +158,22 @@ class DenseNet(nn.Module):
bn_size: int = 4,
drop_rate: float = 0,
num_classes: int = 1000,
memory_efficient: bool = False
memory_efficient: bool = False,
) -> None:
super(DenseNet, self).__init__()
# First convolution
self.features = nn.Sequential(OrderedDict([
('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2,
padding=3, bias=False)),
('norm0', nn.BatchNorm2d(num_init_features)),
('relu0', nn.ReLU(inplace=True)),
('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
self.features = nn.Sequential(
OrderedDict(
[
("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
("norm0", nn.BatchNorm2d(num_init_features)),
("relu0", nn.ReLU(inplace=True)),
("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]
)
)
# Each denseblock
num_features = num_init_features
......@@ -186,18 +184,17 @@ class DenseNet(nn.Module):
bn_size=bn_size,
growth_rate=growth_rate,
drop_rate=drop_rate,
memory_efficient=memory_efficient
memory_efficient=memory_efficient,
)
self.features.add_module('denseblock%d' % (i + 1), block)
self.features.add_module("denseblock%d" % (i + 1), block)
num_features = num_features + num_layers * growth_rate
if i != len(block_config) - 1:
trans = _Transition(num_input_features=num_features,
num_output_features=num_features // 2)
self.features.add_module('transition%d' % (i + 1), trans)
trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
self.features.add_module("transition%d" % (i + 1), trans)
num_features = num_features // 2
# Final batch norm
self.features.add_module('norm5', nn.BatchNorm2d(num_features))
self.features.add_module("norm5", nn.BatchNorm2d(num_features))
# Linear layer
self.classifier = nn.Linear(num_features, num_classes)
......@@ -227,7 +224,8 @@ def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# They are also in the checkpoints in model_urls. This pattern is used
# to find such keys.
pattern = re.compile(
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
)
state_dict = load_state_dict_from_url(model_url, progress=progress)
for key in list(state_dict.keys()):
......@@ -246,7 +244,7 @@ def _densenet(
num_init_features: int,
pretrained: bool,
progress: bool,
**kwargs: Any
**kwargs: Any,
) -> DenseNet:
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained:
......@@ -265,8 +263,7 @@ def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
"""
return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress,
**kwargs)
return _densenet("densenet121", 32, (6, 12, 24, 16), 64, pretrained, progress, **kwargs)
def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
......@@ -280,8 +277,7 @@ def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
"""
return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress,
**kwargs)
return _densenet("densenet161", 48, (6, 12, 36, 24), 96, pretrained, progress, **kwargs)
def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
......@@ -295,8 +291,7 @@ def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
"""
return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress,
**kwargs)
return _densenet("densenet169", 32, (6, 12, 32, 32), 64, pretrained, progress, **kwargs)
def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
......@@ -310,5 +305,4 @@ def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any)
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_.
"""
return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress,
**kwargs)
return _densenet("densenet201", 32, (6, 12, 48, 32), 64, pretrained, progress, **kwargs)
import math
import torch
from collections import OrderedDict
from torch import Tensor
from typing import List, Tuple
import torch
from torch import Tensor
from torchvision.ops.misc import FrozenBatchNorm2d
......@@ -61,12 +60,8 @@ class BalancedPositiveNegativeSampler(object):
neg_idx_per_image = negative[perm2]
# create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
)
pos_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
neg_idx_per_image_mask = torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1
......@@ -132,7 +127,7 @@ class BoxCoder(object):
the representation used for training the regressors.
"""
def __init__(self, weights, bbox_xform_clip=math.log(1000. / 16)):
def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)):
# type: (Tuple[float, float, float, float], float) -> None
"""
Args:
......@@ -177,9 +172,7 @@ class BoxCoder(object):
box_sum += val
if box_sum > 0:
rel_codes = rel_codes.reshape(box_sum, -1)
pred_boxes = self.decode_single(
rel_codes, concat_boxes
)
pred_boxes = self.decode_single(rel_codes, concat_boxes)
if box_sum > 0:
pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
return pred_boxes
......@@ -247,8 +240,8 @@ class Matcher(object):
BETWEEN_THRESHOLDS = -2
__annotations__ = {
'BELOW_LOW_THRESHOLD': int,
'BETWEEN_THRESHOLDS': int,
"BELOW_LOW_THRESHOLD": int,
"BETWEEN_THRESHOLDS": int,
}
def __init__(self, high_threshold, low_threshold, allow_low_quality_matches=False):
......@@ -287,13 +280,9 @@ class Matcher(object):
if match_quality_matrix.numel() == 0:
# empty targets or proposals not supported during training
if match_quality_matrix.shape[0] == 0:
raise ValueError(
"No ground-truth boxes available for one of the images "
"during training")
raise ValueError("No ground-truth boxes available for one of the images " "during training")
else:
raise ValueError(
"No proposal boxes available for one of the images "
"during training")
raise ValueError("No proposal boxes available for one of the images " "during training")
# match_quality_matrix is M (gt) x N (predicted)
# Max over gt elements (dim 0) to find best gt candidate for each prediction
......@@ -305,9 +294,7 @@ class Matcher(object):
# Assign candidate matches with low quality to negative (unassigned) values
below_low_threshold = matched_vals < self.low_threshold
between_thresholds = (matched_vals >= self.low_threshold) & (
matched_vals < self.high_threshold
)
between_thresholds = (matched_vals >= self.low_threshold) & (matched_vals < self.high_threshold)
matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
matches[between_thresholds] = self.BETWEEN_THRESHOLDS
......@@ -328,9 +315,7 @@ class Matcher(object):
# For each gt, find the prediction with which it has highest quality
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
# Find highest quality match available, even if it is low, including ties
gt_pred_pairs_of_highest_quality = torch.where(
match_quality_matrix == highest_quality_foreach_gt[:, None]
)
gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
# Example gt_pred_pairs_of_highest_quality:
# tensor([[ 0, 39796],
# [ 1, 32055],
......@@ -350,7 +335,6 @@ class Matcher(object):
class SSDMatcher(Matcher):
def __init__(self, threshold):
super().__init__(threshold, threshold, allow_low_quality_matches=False)
......@@ -359,9 +343,9 @@ class SSDMatcher(Matcher):
# For each gt, find the prediction with which it has the highest quality
_, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
matches[highest_quality_pred_foreach_gt] = torch.arange(highest_quality_pred_foreach_gt.size(0),
dtype=torch.int64,
device=highest_quality_pred_foreach_gt.device)
matches[highest_quality_pred_foreach_gt] = torch.arange(
highest_quality_pred_foreach_gt.size(0), dtype=torch.int64, device=highest_quality_pred_foreach_gt.device
)
return matches
......@@ -405,7 +389,7 @@ def retrieve_out_channels(model, size):
tmp_img = torch.zeros((1, 3, size[1], size[0]), device=device)
features = model(tmp_img)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
features = OrderedDict([("0", features)])
out_channels = [x.size(1) for x in features.values()]
if in_training:
......
import math
from typing import List, Optional
import torch
from torch import nn, Tensor
from typing import List, Optional
from .image_list import ImageList
......@@ -48,15 +49,21 @@ class AnchorGenerator(nn.Module):
self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
for size, aspect_ratio in zip(sizes, aspect_ratios)]
self.cell_anchors = [
self.generate_anchors(size, aspect_ratio) for size, aspect_ratio in zip(sizes, aspect_ratios)
]
# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu")):
def generate_anchors(
self,
scales: List[int],
aspect_ratios: List[float],
dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu"),
):
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
......@@ -69,8 +76,7 @@ class AnchorGenerator(nn.Module):
return base_anchors.round()
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device)
for cell_anchor in self.cell_anchors]
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
......@@ -83,25 +89,21 @@ class AnchorGenerator(nn.Module):
assert cell_anchors is not None
if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
raise ValueError("Anchors should be Tuple[Tuple[int]] because each feature "
"map could potentially have different sizes and aspect ratios. "
"There needs to be a match between the number of "
"feature maps passed and the number of sizes / aspect ratios specified.")
for size, stride, base_anchors in zip(
grid_sizes, strides, cell_anchors
):
raise ValueError(
"Anchors should be Tuple[Tuple[int]] because each feature "
"map could potentially have different sizes and aspect ratios. "
"There needs to be a match between the number of "
"feature maps passed and the number of sizes / aspect ratios specified."
)
for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device
# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.int32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.int32, device=device
) * stride_height
shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width
shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
......@@ -109,9 +111,7 @@ class AnchorGenerator(nn.Module):
# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)
anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
return anchors
......@@ -119,8 +119,13 @@ class AnchorGenerator(nn.Module):
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
strides = [
[
torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device),
]
for g in grid_sizes
]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
anchors: List[List[torch.Tensor]] = []
......@@ -149,8 +154,15 @@ class DefaultBoxGenerator(nn.Module):
is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
"""
def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9,
scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True):
def __init__(
self,
aspect_ratios: List[List[int]],
min_ratio: float = 0.15,
max_ratio: float = 0.9,
scales: Optional[List[float]] = None,
steps: Optional[List[int]] = None,
clip: bool = True,
):
super().__init__()
if steps is not None:
assert len(aspect_ratios) == len(steps)
......@@ -172,8 +184,9 @@ class DefaultBoxGenerator(nn.Module):
self._wh_pairs = self._generate_wh_pairs(num_outputs)
def _generate_wh_pairs(self, num_outputs: int, dtype: torch.dtype = torch.float32,
device: torch.device = torch.device("cpu")) -> List[Tensor]:
def _generate_wh_pairs(
self, num_outputs: int, dtype: torch.dtype = torch.float32, device: torch.device = torch.device("cpu")
) -> List[Tensor]:
_wh_pairs: List[Tensor] = []
for k in range(num_outputs):
# Adding the 2 default width-height pairs for aspect ratio 1 and scale s'k
......@@ -196,8 +209,9 @@ class DefaultBoxGenerator(nn.Module):
return [2 + 2 * len(r) for r in self.aspect_ratios]
# Default Boxes calculation based on page 6 of SSD paper
def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int],
dtype: torch.dtype = torch.float32) -> Tensor:
def _grid_default_boxes(
self, grid_sizes: List[List[int]], image_size: List[int], dtype: torch.dtype = torch.float32
) -> Tensor:
default_boxes = []
for k, f_k in enumerate(grid_sizes):
# Now add the default boxes for each width-height pair
......@@ -224,12 +238,12 @@ class DefaultBoxGenerator(nn.Module):
return torch.cat(default_boxes, dim=0)
def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'aspect_ratios={aspect_ratios}'
s += ', clip={clip}'
s += ', scales={scales}'
s += ', steps={steps}'
s += ')'
s = self.__class__.__name__ + "("
s += "aspect_ratios={aspect_ratios}"
s += ", clip={clip}"
s += ", scales={scales}"
s += ", steps={steps}"
s += ")"
return s.format(**self.__dict__)
def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
......@@ -242,8 +256,13 @@ class DefaultBoxGenerator(nn.Module):
dboxes = []
for _ in image_list.image_sizes:
dboxes_in_image = default_boxes
dboxes_in_image = torch.cat([dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:]], -1)
dboxes_in_image = torch.cat(
[
dboxes_in_image[:, :2] - 0.5 * dboxes_in_image[:, 2:],
dboxes_in_image[:, :2] + 0.5 * dboxes_in_image[:, 2:],
],
-1,
)
dboxes_in_image[:, 0::2] *= image_size[1]
dboxes_in_image[:, 1::2] *= image_size[0]
dboxes.append(dboxes_in_image)
......
import warnings
from torch import nn
from torchvision.ops import misc as misc_nn_ops
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import mobilenet
from .. import resnet
from .._utils import IntermediateLayerGetter
class BackboneWithFPN(nn.Module):
......@@ -26,6 +27,7 @@ class BackboneWithFPN(nn.Module):
Attributes:
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super(BackboneWithFPN, self).__init__()
......@@ -52,7 +54,7 @@ def resnet_fpn_backbone(
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None
extra_blocks=None,
):
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
......@@ -89,15 +91,13 @@ def resnet_fpn_backbone(
a new list of feature maps and their corresponding names. By
default a ``LastLevelMaxPool`` is used.
"""
backbone = resnet.__dict__[backbone_name](
pretrained=pretrained,
norm_layer=norm_layer)
backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
# select layers that wont be frozen
assert 0 <= trainable_layers <= 5
layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers]
layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
if trainable_layers == 5:
layers_to_train.append('bn1')
layers_to_train.append("bn1")
for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)
......@@ -108,7 +108,7 @@ def resnet_fpn_backbone(
if returned_layers is None:
returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5
return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}
return_layers = {f"layer{k}": str(v) for v, k in enumerate(returned_layers)}
in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
......@@ -123,7 +123,8 @@ def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value,
warnings.warn(
"Changing trainable_backbone_layers has not effect if "
"neither pretrained nor pretrained_backbone have been set to True, "
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value)
)
trainable_backbone_layers = max_value
# by default freeze first blocks
......@@ -140,7 +141,7 @@ def mobilenet_backbone(
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None
extra_blocks=None,
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
......@@ -165,7 +166,7 @@ def mobilenet_backbone(
if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
return_layers = {f"{stage_indices[k]}": str(v) for v, k in enumerate(returned_layers)}
in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
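With fpn enabled, mobilenet_backbone cuts the MobileNetV3 feature extractor at the recorded stage boundaries and wraps it in the same BackboneWithFPN used for ResNets. A hedged usage sketch:

import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True, trainable_layers=2)
feats = backbone(torch.rand(1, 3, 320, 320))
print(list(feats.keys()))  # expected: ['0', '1', 'pool'] - the two returned stages plus the extra pooled level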
......
from torch import nn
import torch.nn.functional as F
from torchvision.ops import MultiScaleRoIAlign
from ._utils import overwrite_eps
from ..._internally_replaced_utils import load_state_dict_from_url
from .anchor_utils import AnchorGenerator
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
from .generalized_rcnn import GeneralizedRCNN
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
__all__ = [
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn",
"fasterrcnn_mobilenet_v3_large_fpn"
"FasterRCNN",
"fasterrcnn_resnet50_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn",
"fasterrcnn_mobilenet_v3_large_fpn",
]
......@@ -141,30 +141,48 @@ class FasterRCNN(GeneralizedRCNN):
>>> predictions = model(x)
"""
def __init__(self, backbone, num_classes=None,
# transform parameters
min_size=800, max_size=1333,
image_mean=None, image_std=None,
# RPN parameters
rpn_anchor_generator=None, rpn_head=None,
rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None, box_head=None, box_predictor=None,
box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,
box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5,
box_batch_size_per_image=512, box_positive_fraction=0.25,
bbox_reg_weights=None):
def __init__(
self,
backbone,
num_classes=None,
# transform parameters
min_size=800,
max_size=1333,
image_mean=None,
image_std=None,
# RPN parameters
rpn_anchor_generator=None,
rpn_head=None,
rpn_pre_nms_top_n_train=2000,
rpn_pre_nms_top_n_test=1000,
rpn_post_nms_top_n_train=2000,
rpn_post_nms_top_n_test=1000,
rpn_nms_thresh=0.7,
rpn_fg_iou_thresh=0.7,
rpn_bg_iou_thresh=0.3,
rpn_batch_size_per_image=256,
rpn_positive_fraction=0.5,
rpn_score_thresh=0.0,
# Box parameters
box_roi_pool=None,
box_head=None,
box_predictor=None,
box_score_thresh=0.05,
box_nms_thresh=0.5,
box_detections_per_img=100,
box_fg_iou_thresh=0.5,
box_bg_iou_thresh=0.5,
box_batch_size_per_image=512,
box_positive_fraction=0.25,
bbox_reg_weights=None,
):
if not hasattr(backbone, "out_channels"):
raise ValueError(
"backbone should contain an attribute out_channels "
"specifying the number of output channels (assumed to be the "
"same for all the levels)")
"same for all the levels)"
)
assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
......@@ -174,58 +192,59 @@ class FasterRCNN(GeneralizedRCNN):
raise ValueError("num_classes should be None when box_predictor is specified")
else:
if box_predictor is None:
raise ValueError("num_classes should not be None when box_predictor "
"is not specified")
raise ValueError("num_classes should not be None when box_predictor " "is not specified")
out_channels = backbone.out_channels
if rpn_anchor_generator is None:
anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
rpn_anchor_generator = AnchorGenerator(
anchor_sizes, aspect_ratios
)
rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
if rpn_head is None:
rpn_head = RPNHead(
out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
)
rpn_head = RPNHead(out_channels, rpn_anchor_generator.num_anchors_per_location()[0])
rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test)
rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test)
rpn = RegionProposalNetwork(
rpn_anchor_generator, rpn_head,
rpn_fg_iou_thresh, rpn_bg_iou_thresh,
rpn_batch_size_per_image, rpn_positive_fraction,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
score_thresh=rpn_score_thresh)
rpn_anchor_generator,
rpn_head,
rpn_fg_iou_thresh,
rpn_bg_iou_thresh,
rpn_batch_size_per_image,
rpn_positive_fraction,
rpn_pre_nms_top_n,
rpn_post_nms_top_n,
rpn_nms_thresh,
score_thresh=rpn_score_thresh,
)
if box_roi_pool is None:
box_roi_pool = MultiScaleRoIAlign(
featmap_names=['0', '1', '2', '3'],
output_size=7,
sampling_ratio=2)
box_roi_pool = MultiScaleRoIAlign(featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2)
if box_head is None:
resolution = box_roi_pool.output_size[0]
representation_size = 1024
box_head = TwoMLPHead(
out_channels * resolution ** 2,
representation_size)
box_head = TwoMLPHead(out_channels * resolution ** 2, representation_size)
if box_predictor is None:
representation_size = 1024
box_predictor = FastRCNNPredictor(
representation_size,
num_classes)
box_predictor = FastRCNNPredictor(representation_size, num_classes)
roi_heads = RoIHeads(
# Box
box_roi_pool, box_head, box_predictor,
box_fg_iou_thresh, box_bg_iou_thresh,
box_batch_size_per_image, box_positive_fraction,
box_roi_pool,
box_head,
box_predictor,
box_fg_iou_thresh,
box_bg_iou_thresh,
box_batch_size_per_image,
box_positive_fraction,
bbox_reg_weights,
box_score_thresh, box_nms_thresh, box_detections_per_img)
box_score_thresh,
box_nms_thresh,
box_detections_per_img,
)
if image_mean is None:
image_mean = [0.485, 0.456, 0.406]
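When no statistics are supplied, the transform falls back to the ImageNet channel means shown here (and, presumably, the usual companion standard deviations); inputs are normalized before being resized into the [min_size, max_size] range. A small sketch of that normalization, assuming the conventional std values:

import torch

image_mean = [0.485, 0.456, 0.406]
image_std = [0.229, 0.224, 0.225]  # assumed companion default, not shown in this hunk
img = torch.rand(3, 480, 640)
normalized = (img - torch.tensor(image_mean)[:, None, None]) / torch.tensor(image_std)[:, None, None]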
......@@ -286,17 +305,15 @@ class FastRCNNPredictor(nn.Module):
model_urls = {
'fasterrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
'fasterrcnn_mobilenet_v3_large_320_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
'fasterrcnn_mobilenet_v3_large_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth'
"fasterrcnn_resnet50_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
"fasterrcnn_mobilenet_v3_large_320_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
"fasterrcnn_mobilenet_v3_large_fpn_coco": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
}
def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
def fasterrcnn_resnet50_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
"""
Constructs a Faster R-CNN model with a ResNet-50-FPN backbone.
......@@ -362,36 +379,54 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
if pretrained:
# no need to download the backbone if pretrained is set
pretrained_backbone = False
backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers)
backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
model = FasterRCNN(backbone, num_classes, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
progress=progress)
state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model
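A short, hedged inference sketch for the constructor above; the COCO checkpoint covers 91 classes (background included) and the model accepts a list of arbitrarily sized 3-channel tensors:

import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
with torch.no_grad():
    predictions = model([torch.rand(3, 300, 400), torch.rand(3, 500, 400)])
# each prediction is a dict with "boxes" (x1, y1, x2, y2), "labels" and "scores"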
def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91,
pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
def _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=False,
progress=True,
num_classes=91,
pretrained_backbone=True,
trainable_backbone_layers=None,
**kwargs,
):
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3
)
if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
trainable_layers=trainable_backbone_layers)
anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
backbone = mobilenet_backbone(
"mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers
)
anchor_sizes = (
(
32,
64,
128,
256,
512,
),
) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
**kwargs)
model = FasterRCNN(
backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs
)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
......@@ -400,8 +435,9 @@ def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=
return model
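The anchor configuration assembled above repeats the same five sizes and three aspect ratios for each of the three feature levels exposed by the MobileNetV3 FPN. Written compactly, with the per-level anchor count checked:

from torchvision.models.detection.anchor_utils import AnchorGenerator

anchor_sizes = ((32, 64, 128, 256, 512),) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
print(anchor_generator.num_anchors_per_location())  # [15, 15, 15]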
def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, **kwargs):
def fasterrcnn_mobilenet_v3_large_320_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
"""
Constructs a low-resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tuned for mobile use cases.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
......@@ -433,13 +469,20 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_c
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
num_classes=num_classes, pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers, **kwargs)
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, **kwargs):
return _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=pretrained,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
def fasterrcnn_mobilenet_v3_large_fpn(
pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=None, **kwargs
):
"""
Constructs a high-resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See
......@@ -467,6 +510,12 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
}
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
num_classes=num_classes, pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers, **kwargs)
return _fasterrcnn_mobilenet_v3_large_fpn(
weights_name,
pretrained=pretrained,
progress=progress,
num_classes=num_classes,
pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers,
**kwargs,
)
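Both public MobileNetV3 constructors delegate to _fasterrcnn_mobilenet_v3_large_fpn and differ only in the defaults they merge into kwargs; the 320 variant targets low-resolution, mobile-friendly inference. A hedged comparison sketch:

import torchvision

hi_res = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
low_res = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, pretrained_backbone=False)
print(hi_res.transform.min_size, low_res.transform.min_size)  # the 320 variant resizes to a much smaller shorter side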
......@@ -2,11 +2,12 @@
Implements the Generalized R-CNN framework
"""
import warnings
from collections import OrderedDict
from typing import Tuple, List, Dict, Optional, Union
import torch
from torch import nn, Tensor
class GeneralizedRCNN(nn.Module):
......@@ -61,12 +62,11 @@ class GeneralizedRCNN(nn.Module):
boxes = target["boxes"]
if isinstance(boxes, torch.Tensor):
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
raise ValueError("Expected target boxes to be a tensor"
"of shape [N, 4], got {:}.".format(
boxes.shape))
raise ValueError(
"Expected target boxes to be a tensor" "of shape [N, 4], got {:}.".format(boxes.shape)
)
else:
raise ValueError("Expected target boxes to be of type "
"Tensor, got {:}.".format(type(boxes)))
raise ValueError("Expected target boxes to be of type " "Tensor, got {:}.".format(type(boxes)))
original_image_sizes: List[Tuple[int, int]] = []
for img in images:
......@@ -86,13 +86,14 @@ class GeneralizedRCNN(nn.Module):
# print the first degenerate box
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
.format(degen_bb, target_idx))
raise ValueError(
"All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}.".format(degen_bb, target_idx)
)
features = self.backbone(images.tensors)
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
features = OrderedDict([("0", features)])
proposals, proposal_losses = self.rpn(images, features, targets)
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
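The forward pass above chains transform -> backbone -> (wrap a bare tensor in an OrderedDict) -> rpn -> roi_heads -> postprocess. A schematic paraphrase as a hypothetical free function rather than the class method:

import torch
from collections import OrderedDict

def run_detector(transform, backbone, rpn, roi_heads, images, targets=None):
    original_sizes = [img.shape[-2:] for img in images]
    image_list, targets = transform(images, targets)  # resize + normalize, batch into an ImageList
    features = backbone(image_list.tensors)
    if isinstance(features, torch.Tensor):  # single feature map -> dict, mirroring the code above
        features = OrderedDict([("0", features)])
    proposals, proposal_losses = rpn(image_list, features, targets)
    detections, detector_losses = roi_heads(features, proposals, image_list.image_sizes, targets)
    detections = transform.postprocess(detections, image_list.image_sizes, original_sizes)
    return detections, {**detector_losses, **proposal_losses}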
......
from typing import List, Tuple
import torch
from torch import Tensor
class ImageList(object):
......@@ -20,6 +21,6 @@ class ImageList(object):
self.tensors = tensors
self.image_sizes = image_sizes
def to(self, device: torch.device) -> 'ImageList':
def to(self, device: torch.device) -> "ImageList":
cast_tensor = self.tensors.to(device)
return ImageList(cast_tensor, self.image_sizes)
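ImageList simply pairs the padded batch tensor with each image's pre-padding size, and to() only moves the tensor (the sizes are plain Python ints). A small usage sketch:

import torch
from torchvision.models.detection.image_list import ImageList

batch = torch.zeros(2, 3, 800, 800)  # padded batch
sizes = [(800, 600), (750, 800)]     # per-image sizes before padding
il = ImageList(batch, sizes)
il = il.to(torch.device("cuda")) if torch.cuda.is_available() else il
print(il.tensors.shape, il.image_sizes)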