Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
import os
from typing import Any, Callable, List, Optional, Tuple
-import torch
import torch.utils.data as data
from ..utils import _log_api_usage_once
@@ -36,7 +35,7 @@ class VisionDataset(data.Dataset):
        target_transform: Optional[Callable] = None,
    ) -> None:
        _log_api_usage_once(self)
-        if isinstance(root, torch._six.string_classes):
+        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root
...
@@ -137,13 +137,13 @@ class WIDERFace(VisionDataset):
                {
                    "img_path": img_path,
                    "annotations": {
-                        "bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                        "blur": labels_tensor[:, 4],
-                        "expression": labels_tensor[:, 5],
-                        "illumination": labels_tensor[:, 6],
-                        "occlusion": labels_tensor[:, 7],
-                        "pose": labels_tensor[:, 8],
-                        "invalid": labels_tensor[:, 9],
+                        "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                        "blur": labels_tensor[:, 4].clone(),
+                        "expression": labels_tensor[:, 5].clone(),
+                        "illumination": labels_tensor[:, 6].clone(),
+                        "occlusion": labels_tensor[:, 7].clone(),
+                        "pose": labels_tensor[:, 8].clone(),
+                        "invalid": labels_tensor[:, 9].clone(),
                    },
                }
            )
...
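A small illustration of why the WIDERFace annotation slices above now call .clone(): indexing a tensor returns a view that shares storage with the full labels_tensor, while a clone is an independent copy (the shapes here are made up for the sketch).

import torch

labels_tensor = torch.zeros(3, 10)
bbox_view = labels_tensor[:, 0:4]           # a view: shares storage with labels_tensor
bbox_copy = labels_tensor[:, 0:4].clone()   # an independent copy

labels_tensor[0, 0] = 1.0
print(bbox_view[0, 0])   # tensor(1.) - the view sees the later write
print(bbox_copy[0, 0])   # tensor(0.) - the clone does not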
-import ctypes
import os
import sys
-from warnings import warn

import torch
@@ -22,7 +20,7 @@ try:
    # conda environment/bin path is configured Please take a look:
    # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
    # Please note: if some path can't be added using add_dll_directory we simply ignore this path
-    if os.name == "nt" and sys.version_info >= (3, 8) and sys.version_info < (3, 9):
+    if os.name == "nt" and sys.version_info < (3, 9):
        env_path = os.environ["PATH"]
        path_arr = env_path.split(";")
        for path in path_arr:
@@ -76,9 +74,9 @@ def _check_cuda_version():
        t_version = torch_version_cuda.split(".")
        t_major = int(t_version[0])
        t_minor = int(t_version[1])
-        if t_major != tv_major or t_minor != tv_minor:
+        if t_major != tv_major:
            raise RuntimeError(
-                "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
+                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                f"CUDA Version={tv_major}.{tv_minor}. "
                "Please reinstall the torchvision that matches your PyTorch install."
@@ -88,19 +86,6 @@ def _check_cuda_version():
def _load_library(lib_name):
    lib_path = _get_extension_path(lib_name)
-    # On Windows Python-3.8+ has `os.add_dll_directory` call,
-    # which is called from _get_extension_path to configure dll search path
-    # Condition below adds a workaround for older versions by
-    # explicitly calling `LoadLibraryExW` with the following flags:
-    #  - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
-    #  - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100)
-    if os.name == "nt" and sys.version_info < (3, 8):
-        _kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
-        if hasattr(_kernel32, "LoadLibraryExW"):
-            _kernel32.LoadLibraryExW(lib_path, None, 0x00001100)
-        else:
-            warn("LoadLibraryExW is missing in kernel32.dll")
    torch.ops.load_library(lib_path)
...
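A condensed sketch of the relaxed CUDA check above: only the major version of PyTorch's CUDA build has to match torchvision's. The tv_major/tv_minor values below are placeholders; the real ones are read from the compiled torchvision extension.

import torch

tv_major, tv_minor = 11, 7                  # placeholder for torchvision's compiled CUDA version
torch_version_cuda = torch.version.cuda     # e.g. "11.8", or None for CPU-only builds
if torch_version_cuda is not None:
    t_major, t_minor = (int(v) for v in torch_version_cuda.split(".")[:2])
    if t_major != tv_major:                 # minor mismatches (11.7 vs 11.8) are now tolerated
        raise RuntimeError(
            f"CUDA major version mismatch: PyTorch {t_major}.{t_minor} vs torchvision {tv_major}.{tv_minor}"
        )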
@@ -8,6 +8,7 @@ try:
    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
except ModuleNotFoundError:
    _HAS_GPU_VIDEO_DECODER = False
from ._video_opt import (
    _HAS_VIDEO_OPT,
    _probe_video_from_file,
...
@@ -137,8 +137,7 @@ def _read_video_from_file(
    audio_timebase: Fraction = default_timebase,
) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
    """
-    Reads a video from a file, returning both the video frames as well as
-    the audio frames
+    Reads a video from a file, returning both the video frames and the audio frames
    Args:
        filename (str): path to the video file
@@ -281,8 +280,7 @@ def _read_video_from_memory(
    audio_timebase_denominator: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
-    Reads a video from memory, returning both the video frames as well as
-    the audio frames
+    Reads a video from memory, returning both the video frames as the audio frames
    This function is torchscriptable.
    Args:
@@ -336,7 +334,10 @@ def _read_video_from_memory(
    _validate_pts(audio_pts_range)
    if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
    result = torch.ops.video_reader.read_video_from_memory(
        video_data,
@@ -378,7 +379,10 @@ def _read_video_timestamps_from_memory(
    is much faster than read_video(...)
    """
    if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
    result = torch.ops.video_reader.read_video_from_memory(
        video_data,
        0,  # seek_frame_margin
@@ -416,7 +420,10 @@ def _probe_video_from_memory(
    This function is torchscriptable
    """
    if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
    result = torch.ops.video_reader.probe_video_from_memory(video_data)
    vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
    info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
...
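For context, a minimal sketch of the pattern the hunks above add around torch.frombuffer when the input is an immutable bytes object (the file name is hypothetical):

import warnings

import torch

with open("clip.mp4", "rb") as f:
    raw = f.read()                          # bytes: a read-only buffer

with warnings.catch_warnings():
    # torch.frombuffer warns because writing through the tensor would be undefined;
    # the decoder only ever reads from it, so the warning is silenced.
    warnings.filterwarnings("ignore", message="The given buffer is not writable")
    video_data = torch.frombuffer(raw, dtype=torch.uint8)

print(video_data.dtype, video_data.numel())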
@@ -10,7 +10,12 @@ from ..utils import _log_api_usage_once
try:
    _load_library("image")
except (ImportError, OSError) as e:
-    warn(f"Failed to load image Python extension: {e}")
+    warn(
+        f"Failed to load image Python extension: '{e}'"
+        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
+        f"Otherwise, there might be something wrong with your environment. "
+        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
+    )

class ImageReadMode(Enum):
@@ -50,7 +55,7 @@ def read_file(path: str) -> torch.Tensor:
def write_file(filename: str, data: torch.Tensor) -> None:
    """
-    Writes the contents of a uint8 tensor with one dimension to a
+    Writes the contents of an uint8 tensor with one dimension to a
    file.
    Args:
...
@@ -12,7 +12,6 @@ import torch
from ..utils import _log_api_usage_once
from . import _video_opt

try:
    import av
@@ -242,8 +241,7 @@ def read_video(
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
-    Reads a video from a file, returning both the video frames as well as
-    the audio frames
+    Reads a video from a file, returning both the video frames and the audio frames
    Args:
        filename (str): path to the video file
...
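Usage of the documented function stays the same; a minimal call (with a hypothetical local file) that returns the video frames, audio frames and metadata mentioned in the docstring:

import torchvision

vframes, aframes, info = torchvision.io.read_video("clip.mp4", pts_unit="sec", output_format="THWC")
print(vframes.shape)    # [T, H, W, 3] uint8 frames
print(aframes.shape)    # [K, L] audio samples (channels x points)
print(info)             # e.g. {'video_fps': ..., 'audio_fps': ...}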
-from typing import Any, Dict, Iterator
+import io
+import warnings
+from typing import Any, Dict, Iterator, Optional

import torch

from ..utils import _log_api_usage_once
-try:
-    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
-except ModuleNotFoundError:
-    _HAS_GPU_VIDEO_DECODER = False
from ._video_opt import _HAS_VIDEO_OPT

if _HAS_VIDEO_OPT:
@@ -21,11 +20,37 @@ else:
        return False

+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date). See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
class VideoReader:
    """
    Fine-grained video-reading API.
    Supports frame-by-frame reading of various streams from a single video
-    container.
+    container. Much like previous video_reader API it supports the following
+    backends: video_reader, pyav, and cuda.
+    Backends can be set via `torchvision.set_video_backend` function.

    .. betastatus:: VideoReader class
@@ -66,13 +91,18 @@ class VideoReader:
    Each stream descriptor consists of two parts: stream type (e.g. 'video') and
    a unique stream id (which are determined by the video encoding).
-    In this way, if the video contaner contains multiple
-    streams of the same type, users can acces the one they want.
+    In this way, if the video container contains multiple
+    streams of the same type, users can access the one they want.
    If only stream type is passed, the decoder auto-detects first stream of that type.

    Args:
+        src (string, bytes object, or tensor): The media source.
+            If string-type, it must be a file path supported by FFMPEG.
+            If bytes, should be an in-memory representation of a file supported by FFMPEG.
+            If Tensor, it is interpreted internally as byte buffer.
+            It must be one-dimensional, of type ``torch.uint8``.
-        path (string): Path to the video file in supported format
        stream (string, optional): descriptor of the required stream, followed by the stream id,
            in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
@@ -82,30 +112,73 @@ class VideoReader:
            Default value (0) enables multithreading with codec-dependent heuristic. The performance
            will depend on the version of FFMPEG codecs supported.
-        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
-            To use GPU decoding, pass ``device="cuda"``.
+        path (str, optional):
+            .. warning:
+                This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
+                Please use ``src`` instead.
    """
-    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
+    def __init__(
+        self,
+        src: str = "",
+        stream: str = "video",
+        num_threads: int = 0,
+        path: Optional[str] = None,
+    ) -> None:
        _log_api_usage_once(self)
-        self.is_cuda = False
-        device = torch.device(device)
-        if device.type == "cuda":
-            if not _HAS_GPU_VIDEO_DECODER:
-                raise RuntimeError("Not compiled with GPU decoder support.")
-            self.is_cuda = True
-            self._c = torch.classes.torchvision.GPUDecoder(path, device)
-            return
-        if not _has_video_opt():
-            raise RuntimeError(
-                "Not compiled with video_reader support, "
-                + "to enable video_reader support, please install "
-                + "ffmpeg (version 4.2 is currently supported) and "
-                + "build torchvision from source."
-            )
-        self._c = torch.classes.torchvision.Video(path, stream, num_threads)
+        from .. import get_video_backend
+
+        self.backend = get_video_backend()
+        if isinstance(src, str):
+            if src == "":
+                if path is None:
+                    raise TypeError("src cannot be empty")
+                src = path
+                warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
+        elif isinstance(src, bytes):
+            if self.backend in ["cuda"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
+                )
+            elif self.backend == "pyav":
+                src = io.BytesIO(src)
+            else:
+                with warnings.catch_warnings():
+                    # Ignore the warning because we actually don't modify the buffer in this function
+                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
+                    src = torch.frombuffer(src, dtype=torch.uint8)
+        elif isinstance(src, torch.Tensor):
+            if self.backend in ["cuda", "pyav"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
+                )
+        else:
+            raise TypeError("`src` must be either string, Tensor or bytes object.")
+
+        if self.backend == "cuda":
+            device = torch.device("cuda")
+            self._c = torch.classes.torchvision.GPUDecoder(src, device)
+        elif self.backend == "video_reader":
+            if isinstance(src, str):
+                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
+            elif isinstance(src, torch.Tensor):
+                self._c = torch.classes.torchvision.Video("", "", 0)
+                self._c.init_from_memory(src, stream, num_threads)
+        elif self.backend == "pyav":
+            self.container = av.open(src, metadata_errors="ignore")
+            # TODO: load metadata
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            # TODO: add extradata exception
+        else:
+            raise RuntimeError("Unknown video backend: {}".format(self.backend))
    def __next__(self) -> Dict[str, Any]:
        """Decodes and returns the next frame of the current stream.
@@ -119,14 +192,29 @@ class VideoReader:
            and corresponding timestamp (``pts``) in seconds
        """
-        if self.is_cuda:
+        if self.backend == "cuda":
            frame = self._c.next()
            if frame.numel() == 0:
                raise StopIteration
-            return {"data": frame}
-        frame, pts = self._c.next()
+            return {"data": frame, "pts": None}
+        elif self.backend == "video_reader":
+            frame, pts = self._c.next()
+        else:
+            try:
+                frame = next(self._c)
+                pts = float(frame.pts * frame.time_base)
+                if "video" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
+                elif "audio" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
+                else:
+                    frame = None
+            except av.error.EOFError:
+                raise StopIteration
        if frame.numel() == 0:
            raise StopIteration
        return {"data": frame, "pts": pts}
    def __iter__(self) -> Iterator[Dict[str, Any]]:
@@ -145,7 +233,18 @@ class VideoReader:
            frame with the exact timestamp if it exists or
            the first frame with timestamp larger than ``time_s``.
        """
-        self._c.seek(time_s, keyframes_only)
+        if self.backend in ["cuda", "video_reader"]:
+            self._c.seek(time_s, keyframes_only)
+        else:
+            # handle special case as pyav doesn't catch it
+            if time_s < 0:
+                time_s = 0
+            temp_str = self.container.streams.get(**self.pyav_stream)[0]
+            offset = int(round(time_s / temp_str.time_base))
+            if not keyframes_only:
+                warnings.warn("Accurate seek is not implemented for pyav backend")
+            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
+            self._c = self.container.decode(**self.pyav_stream)
        return self
    def get_metadata(self) -> Dict[str, Any]:
@@ -154,6 +253,21 @@ class VideoReader:
        Returns:
            (dict): dictionary containing duration and frame rate for every stream
        """
+        if self.backend == "pyav":
+            metadata = {}  # type: Dict[str, Any]
+            for stream in self.container.streams:
+                if stream.type not in metadata:
+                    if stream.type == "video":
+                        rate_n = "fps"
+                    else:
+                        rate_n = "framerate"
+                    metadata[stream.type] = {rate_n: [], "duration": []}
+                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
+                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
+                metadata[stream.type][rate_n].append(float(rate))
+            return metadata
        return self._c.get_metadata()
    def set_current_stream(self, stream: str) -> bool:
@@ -165,14 +279,20 @@ class VideoReader:
        Currently available stream types include ``['video', 'audio']``.
        Each descriptor consists of two parts: stream type (e.g. 'video') and
        a unique stream id (which are determined by video encoding).
-        In this way, if the video contaner contains multiple
-        streams of the same type, users can acces the one they want.
+        In this way, if the video container contains multiple
+        streams of the same type, users can access the one they want.
        If only stream type is passed, the decoder auto-detects first stream
        of that type and returns it.
        Returns:
-            (bool): True on succes, False otherwise
+            (bool): True on success, False otherwise
        """
-        if self.is_cuda:
-            print("GPU decoding only works with video stream.")
+        if self.backend == "cuda":
+            warnings.warn("GPU decoding only works with video stream.")
+        if self.backend == "pyav":
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            return True
        return self._c.set_current_stream(stream)
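A hedged usage sketch of the reworked VideoReader above: the backend comes from torchvision.set_video_backend, src may be a path, a bytes object or a uint8 tensor (subject to the backend restrictions enforced in __init__), and path= still works but warns. The file name is hypothetical.

import torchvision
from torchvision.io import VideoReader

torchvision.set_video_backend("pyav")        # or "video_reader" / "cuda", if available

with open("clip.mp4", "rb") as f:
    reader = VideoReader(f.read(), "video")  # bytes are wrapped in io.BytesIO for the pyav backend

print(reader.get_metadata())                 # e.g. {'video': {'fps': [...], 'duration': [...]}, ...}
reader.seek(2.0, keyframes_only=True)        # non-keyframe (accurate) seek warns on pyav
frame = next(reader)
print(frame["data"].shape, frame["pts"])     # C, H, W frame tensor and its timestamp in seconds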
@@ -15,4 +15,9 @@ from .vision_transformer import *
from .swin_transformer import *
from .maxvit import *
from . import detection, optical_flow, quantization, segmentation, video
-from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
+
+# The Weights and WeightsEnum are developer-facing utils that we make public for
+# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
+# TODO: we could / should document them publicly, but it's not clear where, as
+# they're not intended for end users.
+from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
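Re-exporting Weights and WeightsEnum lets downstream libraries (torchgeo is the example cited above) define their own weight enums with the same machinery. A hypothetical sketch, with invented names and URL:

from functools import partial

from torchvision.models import Weights, WeightsEnum
from torchvision.transforms import CenterCrop


class MyBackbone_Weights(WeightsEnum):
    # Hypothetical entry; downstream libraries register real checkpoints this way.
    SENTINEL2 = Weights(
        url="https://example.com/my_backbone_sentinel2.pth",
        transforms=partial(CenterCrop, size=224),
        meta={"num_params": 11_000_000},
    )


print(MyBackbone_Weights.SENTINEL2.meta["num_params"])
print(MyBackbone_Weights.SENTINEL2.url)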
+import fnmatch
import importlib
import inspect
import sys
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
from inspect import signature
from types import ModuleType
-from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union

from torch import nn

-from torchvision._utils import StrEnum
from .._internally_replaced_utils import load_state_dict_from_url
@@ -37,8 +38,34 @@ class Weights:
    transforms: Callable
    meta: Dict[str, Any]

+    def __eq__(self, other: Any) -> bool:
+        # We need this custom implementation for correct deep-copy and deserialization behavior.
+        # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
+        # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
+        # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
+        # for it, the check against the defined members would fail and effectively prevent the weights from being
+        # deep-copied or deserialized.
+        # See https://github.com/pytorch/vision/pull/7107 for details.
+        if not isinstance(other, Weights):
+            return NotImplemented
+
+        if self.url != other.url:
+            return False
+
+        if self.meta != other.meta:
+            return False
+
+        if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
+            return (
+                self.transforms.func == other.transforms.func
+                and self.transforms.args == other.transforms.args
+                and self.transforms.keywords == other.transforms.keywords
+            )
+        else:
+            return self.transforms == other.transforms
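The behavior the comment describes can be seen directly: functools.partial has no structural equality, so a deep-copied transform compares unequal even though its parts match, which is what the field-wise comparison above works around.

from copy import deepcopy
from functools import partial


def preprocess(img, crop_size):
    return img


fn = partial(preprocess, crop_size=224)
copied = deepcopy(fn)

print(copied == fn)   # False: partial falls back to identity comparison
print(copied.func is fn.func, copied.args == fn.args, copied.keywords == fn.keywords)  # True True True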
-class WeightsEnum(StrEnum):
+class WeightsEnum(Enum):
    """
    This class is the parent class of all model weights. Each model building method receives an optional `weights`
    parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
@@ -48,40 +75,40 @@ class WeightsEnum(StrEnum):
        value (Weights): The data class entry with the weight information.
    """

-    def __init__(self, value: Weights):
-        self._value_ = value
-
    @classmethod
    def verify(cls, obj: Any) -> Any:
        if obj is not None:
            if type(obj) is str:
-                obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
+                obj = cls[obj.replace(cls.__name__ + ".", "")]
            elif not isinstance(obj, cls):
                raise TypeError(
                    f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
                )
        return obj

-    def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
-        return load_state_dict_from_url(self.url, progress=progress)
+    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
+        return load_state_dict_from_url(self.url, *args, **kwargs)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self._name_}"

-    def __getattr__(self, name):
-        # Be able to fetch Weights attributes directly
-        for f in fields(Weights):
-            if f.name == name:
-                return object.__getattribute__(self.value, name)
-        return super().__getattr__(name)
+    @property
+    def url(self):
+        return self.value.url
+
+    @property
+    def transforms(self):
+        return self.value.transforms
+
+    @property
+    def meta(self):
+        return self.value.meta
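With the explicit properties and the *args/**kwargs pass-through, typical usage looks like this sketch (downloads happen lazily through load_state_dict_from_url; check_hash is the same flag torch.hub exposes):

from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.verify("IMAGENET1K_V1")   # accepts the member itself or its name as a string
print(weights.url)                                   # exposed via the new url property
preprocess = weights.transforms()                    # build the inference transforms
state_dict = weights.get_state_dict(progress=True, check_hash=True)
print(len(state_dict))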
def get_weight(name: str) -> WeightsEnum:
    """
    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

-    .. betastatus:: function
-
    Args:
        name (str): The name of the weight enum entry.
@@ -96,7 +123,9 @@ def get_weight(name: str) -> WeightsEnum:
    base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
    base_module = importlib.import_module(base_module_name)
    model_modules = [base_module] + [
-        x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
+        x[1]
+        for x in inspect.getmembers(base_module, inspect.ismodule)
+        if x[1].__file__.endswith("__init__.py")  # type: ignore[union-attr]
    ]

    weights_enum = None
@@ -109,14 +138,12 @@ def get_weight(name: str) -> WeightsEnum:
    if weights_enum is None:
        raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

-    return weights_enum.from_str(value_name)
+    return weights_enum[value_name]
-def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
+def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
    """
-    Retuns the weights enum class associated to the given model.
-
-    .. betastatus:: function
+    Returns the weights enum class associated to the given model.

    Args:
        name (callable or str): The model builder function or the name under which it is registered.
@@ -128,13 +155,12 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
    return _get_enum_from_fn(model)

-def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
+def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
    """
    Internal method that gets the weight enum of a specific model builder method.

    Args:
        fn (Callable): The builder method used to create the model.
-        weight_name (str): The name of the weight enum entry of the specific model.

    Returns:
        WeightsEnum: The requested weight enum.
    """
@@ -159,7 +185,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
        "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
    )

-    return cast(WeightsEnum, weights_enum)
+    return weights_enum
M = TypeVar("M", bound=nn.Module)
@@ -178,21 +204,43 @@ def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], C
    return wrapper

-def list_models(module: Optional[ModuleType] = None) -> List[str]:
+def list_models(
+    module: Optional[ModuleType] = None,
+    include: Union[Iterable[str], str, None] = None,
+    exclude: Union[Iterable[str], str, None] = None,
+) -> List[str]:
    """
    Returns a list with the names of registered models.

-    .. betastatus:: function
-
    Args:
        module (ModuleType, optional): The module from which we want to extract the available models.
+        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
+            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is the union of individual filters.
+        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
+            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.

    Returns:
        models (list): A list with the names of available models.
    """
-    models = [
-        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
-    ]
+    all_models = {
+        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
+    }
+    if include:
+        models: Set[str] = set()
+        if isinstance(include, str):
+            include = [include]
+        for include_filter in include:
+            models = models | set(fnmatch.filter(all_models, include_filter))
+    else:
+        models = all_models
+
+    if exclude:
+        if isinstance(exclude, str):
+            exclude = [exclude]
+        for exclude_filter in exclude:
+            models = models - set(fnmatch.filter(all_models, exclude_filter))
    return sorted(models)
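The new filters compose as documented: include builds a union of fnmatch patterns, exclude then removes matches. A short call, with illustrative output:

from torchvision import models

names = models.list_models(include="*resnet*", exclude="*quantized*")
print(names[:4])    # e.g. ['resnet101', 'resnet152', 'resnet18', 'resnet34']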
@@ -200,8 +248,6 @@ def get_model_builder(name: str) -> Callable[..., nn.Module]:
    """
    Gets the model name and returns the model builder method.

-    .. betastatus:: function
-
    Args:
        name (str): The name under which the model is registered.
@@ -220,8 +266,6 @@ def get_model(name: str, **config: Any) -> nn.Module:
    """
    Gets the model name and configuration and returns an instantiated model.

-    .. betastatus:: function
-
    Args:
        name (str): The name under which the model is registered.
        **config (Any): parameters passed to the model builder method.
...
@@ -191,7 +191,7 @@ def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[D
        # used to be a pretrained parameter.
        pretrained_positional = weights_arg is not sentinel
        if pretrained_positional:
-            # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a
+            # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
            # unified access to the value if the default value is a callable.
            kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
        else:
...
@@ -67,6 +67,8 @@ class AlexNet_Weights(WeightsEnum):
                    "acc@5": 79.066,
                }
            },
+            "_ops": 0.714,
+            "_file_size": 233.087,
            "_docs": """
                These weights reproduce closely the results of the paper using a simplified training recipe.
            """,
@@ -112,17 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
    model = AlexNet(**kwargs)

    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-model_urls = _ModelURLs(
-    {
-        "alexnet": AlexNet_Weights.IMAGENET1K_V1.url,
-    }
-)
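With the legacy model_urls dictionary gone, the URL (and the hash-checked download) comes from the weights enum itself; a short sketch:

from torchvision.models import alexnet, AlexNet_Weights

weights = AlexNet_Weights.IMAGENET1K_V1
print(weights.url)                       # replaces the removed model_urls["alexnet"] lookup

model = alexnet(weights=weights)         # internally calls get_state_dict(progress=..., check_hash=True)
model.eval()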
@@ -189,7 +189,7 @@ def _convnext(
    model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

    if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
@@ -219,6 +219,8 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
                    "acc@5": 96.146,
                }
            },
+            "_ops": 4.456,
+            "_file_size": 109.119,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -237,6 +239,8 @@ class ConvNeXt_Small_Weights(WeightsEnum):
                    "acc@5": 96.650,
                }
            },
+            "_ops": 8.684,
+            "_file_size": 191.703,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -255,6 +259,8 @@ class ConvNeXt_Base_Weights(WeightsEnum):
                    "acc@5": 96.870,
                }
            },
+            "_ops": 15.355,
+            "_file_size": 338.064,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -273,6 +279,8 @@ class ConvNeXt_Large_Weights(WeightsEnum):
                    "acc@5": 96.976,
                }
            },
+            "_ops": 34.361,
+            "_file_size": 754.537,
        },
    )
    DEFAULT = IMAGENET1K_V1
...
@@ -15,7 +15,6 @@ from ._api import register_model, Weights, WeightsEnum
from ._meta import _IMAGENET_CATEGORIES
from ._utils import _ovewrite_named_param, handle_legacy_interface
-
__all__ = [
    "DenseNet",
    "DenseNet121_Weights",
@@ -228,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
        r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
    )

-    state_dict = weights.get_state_dict(progress=progress)
+    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
@@ -278,6 +277,8 @@ class DenseNet121_Weights(WeightsEnum):
                    "acc@5": 91.972,
                }
            },
+            "_ops": 2.834,
+            "_file_size": 30.845,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -296,6 +297,8 @@ class DenseNet161_Weights(WeightsEnum):
                    "acc@5": 93.560,
                }
            },
+            "_ops": 7.728,
+            "_file_size": 110.369,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -314,6 +317,8 @@ class DenseNet169_Weights(WeightsEnum):
                    "acc@5": 92.806,
                }
            },
+            "_ops": 3.36,
+            "_file_size": 54.708,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -332,6 +337,8 @@ class DenseNet201_Weights(WeightsEnum):
                    "acc@5": 93.370,
                }
            },
+            "_ops": 4.291,
+            "_file_size": 77.373,
        },
    )
    DEFAULT = IMAGENET1K_V1
@@ -439,17 +446,3 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool
    weights = DenseNet201_Weights.verify(weights)

    return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-model_urls = _ModelURLs(
-    {
-        "densenet121": DenseNet121_Weights.IMAGENET1K_V1.url,
-        "densenet169": DenseNet169_Weights.IMAGENET1K_V1.url,
-        "densenet201": DenseNet201_Weights.IMAGENET1K_V1.url,
-        "densenet161": DenseNet161_Weights.IMAGENET1K_V1.url,
-    }
-)
@@ -25,7 +25,7 @@ class BalancedPositiveNegativeSampler:
    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Args:
-            matched idxs: list of tensors containing -1, 0 or positive values.
+            matched_idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.
@@ -403,22 +403,14 @@ class Matcher:
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        """
-        # For each gt, find the prediction with which it has highest quality
+        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
-        # Find highest quality match available, even if it is low, including ties
+        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
-        # tensor([[ 0, 39796],
-        #         [ 1, 32055],
-        #         [ 1, 32070],
-        #         [ 2, 39190],
-        #         [ 2, 40255],
-        #         [ 3, 40390],
-        #         [ 3, 41455],
-        #         [ 4, 45470],
-        #         [ 5, 45325],
-        #         [ 5, 46390]])
-        # Each row is a (gt index, prediction index)
+        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
+        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
+        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties
        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
@@ -501,14 +493,14 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
    if K exceeds the number of elements along that axis. Previously, python's min() function was
    used to determine whether to use the provided k-value or the specified dim axis value.

-    However in cases where the model is being exported in tracing mode, python min() is
+    However, in cases where the model is being exported in tracing mode, python min() is
    static causing the model to be traced incorrectly and eventually fail at the topk node.
    In order to avoid this situation, in tracing mode, torch.min() is used instead.

    Args:
-        input (Tensor): The orignal input tensor.
+        input (Tensor): The original input tensor.
        orig_kval (int): The provided k-value.
-        axis(int): Axis along which we retreive the input size.
+        axis(int): Axis along which we retrieve the input size.

    Returns:
        min_kval (int): Appropriately selected k-value.
...
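The rewritten Matcher comment above reflects what torch.where returns for a boolean matrix: a tuple of index tensors, one per dimension, with ties kept. A tiny example with a made-up quality matrix:

import torch

match_quality_matrix = torch.tensor([[0.1, 0.9, 0.9],
                                     [0.8, 0.2, 0.8]])
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
gt_idx, pred_idx = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
print(gt_idx)     # tensor([0, 0, 1, 1]) - gt index per pair, ties included
print(pred_idx)   # tensor([1, 2, 0, 2]) - matching prediction indices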
@@ -61,7 +61,7 @@ class AnchorGenerator(nn.Module):
        aspect_ratios: List[float],
        dtype: torch.dtype = torch.float32,
        device: torch.device = torch.device("cpu"),
-    ):
+    ) -> Tensor:
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
@@ -76,7 +76,7 @@ class AnchorGenerator(nn.Module):
    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
        self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
@@ -145,7 +145,7 @@ class DefaultBoxGenerator(nn.Module):
            of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
        scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using
            the ``min_ratio`` and ``max_ratio`` parameters.
-        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of defalt boxes. If not provided
+        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
            it will be estimated from the data.
        clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
            is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
@@ -201,7 +201,7 @@ class DefaultBoxGenerator(nn.Module):
        _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
        return _wh_pairs

-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
        # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
        return [2 + 2 * len(r) for r in self.aspect_ratios]
...
@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@handle_legacy_interface(
    weights=(
        "pretrained",
-        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
+        lambda kwargs: _get_enum_from_fn(resnet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def resnet_fpn_backbone(
@@ -102,12 +102,12 @@ def resnet_fpn_backbone(
        trainable_layers (int): number of trainable (not frozen) layers starting from final block.
            Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
        returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
-            By default all layers are returned.
+            By default, all layers are returned.
        extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
            be performed. It is expected to take the fpn features, the original
            features and the names of the original features as input, and returns
            a new list of feature maps and their corresponding names. By
-            default a ``LastLevelMaxPool`` is used.
+            default, a ``LastLevelMaxPool`` is used.
    """
    backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer)
    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)
@@ -121,7 +121,7 @@ def _resnet_fpn_extractor(
    norm_layer: Optional[Callable[..., nn.Module]] = None,
) -> BackboneWithFPN:

-    # select layers that wont be frozen
+    # select layers that won't be frozen
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0,5], got {trainable_layers}")
    layers_to_train = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -158,7 +158,7 @@ def _validate_trainable_layers(
    if not is_trained:
        if trainable_backbone_layers is not None:
            warnings.warn(
-                "Changing trainable_backbone_layers has not effect if "
+                "Changing trainable_backbone_layers has no effect if "
                "neither pretrained nor pretrained_backbone have been set to True, "
                f"falling back to trainable_backbone_layers={max_value} so that all layers are trainable"
            )
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@handle_legacy_interface(
    weights=(
        "pretrained",
-        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]]).from_str("IMAGENET1K_V1"),
+        lambda kwargs: _get_enum_from_fn(mobilenet.__dict__[kwargs["backbone_name"]])["IMAGENET1K_V1"],
    ),
)
def mobilenet_backbone(
@@ -208,7 +208,7 @@ def _mobilenet_extractor(
    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
    num_stages = len(stage_indices)

-    # find the index of the layer from which we wont freeze
+    # find the index of the layer from which we won't freeze
    if trainable_layers < 0 or trainable_layers > num_stages:
        raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers} ")
    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
...
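A hedged usage sketch of resnet_fpn_backbone with the enum-style weights shown in the hunks above (keyword lookup instead of from_str):

import torch
from torchvision.models import ResNet50_Weights
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

backbone = resnet_fpn_backbone(
    backbone_name="resnet50",
    weights=ResNet50_Weights.IMAGENET1K_V1,
    trainable_layers=3,
)
features = backbone(torch.rand(1, 3, 224, 224))
print(backbone.out_channels, list(features.keys()))   # 256 and the FPN level names, e.g. ['0', '1', '2', '3', 'pool']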
...@@ -47,9 +47,9 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -47,9 +47,9 @@ class FasterRCNN(GeneralizedRCNN):
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes. image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
...@@ -68,7 +68,7 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -68,7 +68,7 @@ class FasterRCNN(GeneralizedRCNN):
Args: Args:
backbone (nn.Module): the network used to compute the features for the model. backbone (nn.Module): the network used to compute the features for the model.
It should contain a out_channels attribute, which indicates the number of output It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps). channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or and OrderedDict[Tensor]. The backbone should return a single Tensor or and OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background). num_classes (int): number of output classes of the model (including the background).
...@@ -128,7 +128,7 @@ class FasterRCNN(GeneralizedRCNN): ...@@ -128,7 +128,7 @@ class FasterRCNN(GeneralizedRCNN):
>>> # only the features >>> # only the features
>>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
>>> # FasterRCNN needs to know the number of >>> # FasterRCNN needs to know the number of
>>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
>>> # so we need to add it here >>> # so we need to add it here
>>> backbone.out_channels = 1280 >>> backbone.out_channels = 1280
>>> >>>
...@@ -388,6 +388,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -388,6 +388,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 37.0, "box_map": 37.0,
} }
}, },
"_ops": 134.38,
"_file_size": 159.743,
"_docs": """These weights were produced by following a similar training recipe as on the paper.""", "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
}, },
) )
...@@ -407,6 +409,8 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum): ...@@ -407,6 +409,8 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
"box_map": 46.7, "box_map": 46.7,
} }
}, },
"_ops": 280.371,
"_file_size": 167.104,
"_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""", "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
}, },
) )
...@@ -426,6 +430,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): ...@@ -426,6 +430,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
"box_map": 32.8, "box_map": 32.8,
} }
}, },
"_ops": 4.494,
"_file_size": 74.239,
"_docs": """These weights were produced by following a similar training recipe as on the paper.""", "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
}, },
) )
...@@ -445,6 +451,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): ...@@ -445,6 +451,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
"box_map": 22.8, "box_map": 22.8,
} }
}, },
"_ops": 0.719,
"_file_size": 74.239,
"_docs": """These weights were produced by following a similar training recipe as on the paper.""", "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
}, },
) )
...@@ -475,9 +483,9 @@ def fasterrcnn_resnet50_fpn( ...@@ -475,9 +483,9 @@ def fasterrcnn_resnet50_fpn(
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes. image, and should be in ``0-1`` range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
...@@ -563,7 +571,7 @@ def fasterrcnn_resnet50_fpn( ...@@ -563,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1: if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
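A hedged usage sketch of the behavior described in the docstring above: in training mode the model consumes images plus per-image target dictionaries and returns a dict of losses, while in eval mode it returns per-image predictions. Tensor sizes and labels are illustrative:

    import torch
    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    targets = [
        {"boxes": torch.tensor([[50.0, 60.0, 120.0, 200.0]]), "labels": torch.tensor([1])},
        {"boxes": torch.tensor([[30.0, 40.0, 250.0, 300.0]]), "labels": torch.tensor([2])},
    ]

    model.train()
    loss_dict = model(images, targets)   # dict of classification/regression losses

    model.eval()
    with torch.no_grad():
        predictions = model(images)      # list of dicts with boxes, labels, scores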
...@@ -645,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2( ...@@ -645,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -686,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ...@@ -686,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -706,7 +714,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn( ...@@ -706,7 +714,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
**kwargs: Any, **kwargs: Any,
) -> FasterRCNN: ) -> FasterRCNN:
""" """
Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tunned for mobile use cases. Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
.. betastatus:: detection module .. betastatus:: detection module
...@@ -833,16 +841,3 @@ def fasterrcnn_mobilenet_v3_large_fpn( ...@@ -833,16 +841,3 @@ def fasterrcnn_mobilenet_v3_large_fpn(
trainable_backbone_layers=trainable_backbone_layers, trainable_backbone_layers=trainable_backbone_layers,
**kwargs, **kwargs,
) )
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
"fasterrcnn_resnet50_fpn_coco": FasterRCNN_ResNet50_FPN_Weights.COCO_V1.url,
"fasterrcnn_mobilenet_v3_large_320_fpn_coco": FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1.url,
"fasterrcnn_mobilenet_v3_large_fpn_coco": FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1.url,
}
)
...@@ -70,7 +70,7 @@ class FCOSHead(nn.Module): ...@@ -70,7 +70,7 @@ class FCOSHead(nn.Module):
else: else:
gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud gt_classes_targets[matched_idxs_per_image < 0] = -1 # background
all_gt_classes_targets.append(gt_classes_targets) all_gt_classes_targets.append(gt_classes_targets)
all_gt_boxes_targets.append(gt_boxes_targets) all_gt_boxes_targets.append(gt_boxes_targets)
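A minimal standalone sketch of the matching trick in the hunk above (names and values are illustrative): unmatched anchors carry a negative index, so clip(min=0) keeps the gather valid, and the mask afterwards restores the background label:

    import torch

    labels = torch.tensor([3, 7, 5])                 # per-image ground-truth class labels
    matched_idxs = torch.tensor([0, -1, 2, 1, -1])   # -1 marks anchors matched to no object

    gt_classes = labels[matched_idxs.clip(min=0)]    # clip so negative indices gather a valid row
    gt_classes[matched_idxs < 0] = -1                # restore the background marker afterwards
    print(gt_classes)                                # tensor([ 3, -1,  5,  7, -1])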
...@@ -274,9 +274,9 @@ class FCOS(nn.Module): ...@@ -274,9 +274,9 @@ class FCOS(nn.Module):
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes. image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
...@@ -329,7 +329,7 @@ class FCOS(nn.Module): ...@@ -329,7 +329,7 @@ class FCOS(nn.Module):
>>> # only the features >>> # only the features
>>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
>>> # FCOS needs to know the number of >>> # FCOS needs to know the number of
>>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
>>> # so we need to add it here >>> # so we need to add it here
>>> backbone.out_channels = 1280 >>> backbone.out_channels = 1280
>>> >>>
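The FCOS docstring example is also truncated by the diff. A hedged sketch of the remaining steps; the single size/aspect ratio below mirrors the convention torchvision's docs use for this anchor-free detector:

    import torchvision
    from torchvision.models.detection import FCOS
    from torchvision.models.detection.anchor_utils import AnchorGenerator

    backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
    backbone.out_channels = 1280

    # FCOS is anchor-free, so it is modeled with a single "anchor"
    # (one size, one aspect ratio) per feature-map location.
    anchor_generator = AnchorGenerator(sizes=((8,),), aspect_ratios=((1.0,),))

    model = FCOS(backbone, num_classes=80, anchor_generator=anchor_generator)
    model.eval()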
...@@ -662,6 +662,8 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum): ...@@ -662,6 +662,8 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
"box_map": 39.2, "box_map": 39.2,
} }
}, },
"_ops": 128.207,
"_file_size": 123.608,
"_docs": """These weights were produced by following a similar training recipe as on the paper.""", "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
}, },
) )
...@@ -693,9 +695,9 @@ def fcos_resnet50_fpn( ...@@ -693,9 +695,9 @@ def fcos_resnet50_fpn(
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes. image, and should be in ``0-1`` range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
...@@ -764,17 +766,6 @@ def fcos_resnet50_fpn( ...@@ -764,17 +766,6 @@ def fcos_resnet50_fpn(
model = FCOS(backbone, num_classes, **kwargs) model = FCOS(backbone, num_classes, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
"fcos_resnet50_fpn_coco": FCOS_ResNet50_FPN_Weights.COCO_V1.url,
}
)
...@@ -29,9 +29,9 @@ class KeypointRCNN(FasterRCNN): ...@@ -29,9 +29,9 @@ class KeypointRCNN(FasterRCNN):
The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
image, and should be in 0-1 range. Different images can have different sizes. image, and should be in 0-1 range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
...@@ -55,7 +55,7 @@ class KeypointRCNN(FasterRCNN): ...@@ -55,7 +55,7 @@ class KeypointRCNN(FasterRCNN):
Args: Args:
backbone (nn.Module): the network used to compute the features for the model. backbone (nn.Module): the network used to compute the features for the model.
It should contain a out_channels attribute, which indicates the number of output It should contain an out_channels attribute, which indicates the number of output
channels that each feature map has (and it should be the same for all feature maps). channels that each feature map has (and it should be the same for all feature maps).
The backbone should return a single Tensor or an OrderedDict[Tensor]. The backbone should return a single Tensor or an OrderedDict[Tensor].
num_classes (int): number of output classes of the model (including the background). num_classes (int): number of output classes of the model (including the background).
...@@ -121,7 +121,7 @@ class KeypointRCNN(FasterRCNN): ...@@ -121,7 +121,7 @@ class KeypointRCNN(FasterRCNN):
>>> # only the features >>> # only the features
>>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
>>> # KeypointRCNN needs to know the number of >>> # KeypointRCNN needs to know the number of
>>> # output channels in a backbone. For mobilenet_v2, it's 1280 >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
>>> # so we need to add it here >>> # so we need to add it here
>>> backbone.out_channels = 1280 >>> backbone.out_channels = 1280
>>> >>>
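Again the docstring example is cut off by the diff. A hedged sketch of how the box and keypoint RoI poolers are typically wired up; featmap names and pooler sizes are illustrative, mirroring torchvision's docs:

    import torchvision
    from torchvision.models.detection import KeypointRCNN
    from torchvision.models.detection.anchor_utils import AnchorGenerator

    backbone = torchvision.models.mobilenet_v2(weights="DEFAULT").features
    backbone.out_channels = 1280

    anchor_generator = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )
    # Boxes are pooled to a 7x7 grid, keypoint heatmaps to a larger 14x14 grid.
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=["0"], output_size=7, sampling_ratio=2
    )
    keypoint_roi_pooler = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=["0"], output_size=14, sampling_ratio=2
    )

    model = KeypointRCNN(
        backbone,
        num_classes=2,
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        keypoint_roi_pool=keypoint_roi_pooler,
    )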
...@@ -328,6 +328,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -328,6 +328,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"kp_map": 61.1, "kp_map": 61.1,
} }
}, },
"_ops": 133.924,
"_file_size": 226.054,
"_docs": """ "_docs": """
These weights were produced by following a similar training recipe as on the paper but use a checkpoint These weights were produced by following a similar training recipe as on the paper but use a checkpoint
from an early epoch. from an early epoch.
...@@ -347,6 +349,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -347,6 +349,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
"kp_map": 65.0, "kp_map": 65.0,
} }
}, },
"_ops": 137.42,
"_file_size": 226.054,
"_docs": """These weights were produced by following a similar training recipe as on the paper.""", "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
}, },
) )
...@@ -383,9 +387,9 @@ def keypointrcnn_resnet50_fpn( ...@@ -383,9 +387,9 @@ def keypointrcnn_resnet50_fpn(
The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
image, and should be in ``0-1`` range. Different images can have different sizes. image, and should be in ``0-1`` range. Different images can have different sizes.
The behavior of the model changes depending if it is in training or evaluation mode. The behavior of the model changes depending on if it is in training or evaluation mode.
During training, the model expects both the input tensors, as well as a targets (list of dictionary), During training, the model expects both the input tensors and targets (list of dictionary),
containing: containing:
- boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
...@@ -461,21 +465,8 @@ def keypointrcnn_resnet50_fpn( ...@@ -461,21 +465,8 @@ def keypointrcnn_resnet50_fpn(
model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1: if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
overwrite_eps(model, 0.0) overwrite_eps(model, 0.0)
return model return model
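A hedged inference sketch for the builder above: each prediction's keypoints entry is an (M, K, 3) tensor of (x, y, visibility) triplets, with K = 17 for the COCO person keypoints these weights were trained on. Input size is illustrative:

    import torch
    from torchvision.models.detection import (
        keypointrcnn_resnet50_fpn,
        KeypointRCNN_ResNet50_FPN_Weights,
    )

    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights)
    model.eval()

    x = [torch.rand(3, 480, 640)]
    with torch.no_grad():
        out = model(x)[0]

    # Detected persons: boxes (M, 4), scores (M,), keypoints (M, 17, 3).
    print(out["boxes"].shape, out["keypoints"].shape)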
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
# legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
"keypointrcnn_resnet50_fpn_coco_legacy": KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY.url,
"keypointrcnn_resnet50_fpn_coco": KeypointRCNN_ResNet50_FPN_Weights.COCO_V1.url,
}
)