OpenDAS / vision · Commits · cc26cd81

Commit cc26cd81 — merge v0.16.0
Authored Nov 27, 2023 by panning
Parents: f78f29f5, fbb4cc54

Changes: 370 files. Showing 20 changed files with 360 additions and 235 deletions (+360 / -235).
torchvision/datasets/vision.py                      +1    -2
torchvision/datasets/widerface.py                   +7    -7
torchvision/extension.py                            +3    -18
torchvision/io/__init__.py                          +1    -0
torchvision/io/_video_opt.py                        +14   -7
torchvision/io/image.py                             +7    -2
torchvision/io/video.py                             +1    -3
torchvision/io/video_reader.py                      +158  -38
torchvision/models/__init__.py                      +6    -1
torchvision/models/_api.py                          +82   -38
torchvision/models/_utils.py                        +1    -1
torchvision/models/alexnet.py                       +3    -12
torchvision/models/convnext.py                      +9    -1
torchvision/models/densenet.py                      +9    -16
torchvision/models/detection/_utils.py              +9    -17
torchvision/models/detection/anchor_utils.py        +4    -4
torchvision/models/detection/backbone_utils.py      +7    -7
torchvision/models/detection/faster_rcnn.py         +18   -23
torchvision/models/detection/fcos.py                +9    -18
torchvision/models/detection/keypoint_rcnn.py       +11   -20
torchvision/datasets/vision.py

 import os
 from typing import Any, Callable, List, Optional, Tuple

-import torch
 import torch.utils.data as data

 from ..utils import _log_api_usage_once
...
@@ -36,7 +35,7 @@ class VisionDataset(data.Dataset):
         target_transform: Optional[Callable] = None,
     ) -> None:
         _log_api_usage_once(self)
-        if isinstance(root, torch._six.string_classes):
+        if isinstance(root, str):
             root = os.path.expanduser(root)
         self.root = root
...
torchvision/datasets/widerface.py
...
@@ -137,13 +137,13 @@ class WIDERFace(VisionDataset):
                     {
                         "img_path": img_path,
                         "annotations": {
-                            "bbox": labels_tensor[:, 0:4],  # x, y, width, height
-                            "blur": labels_tensor[:, 4],
-                            "expression": labels_tensor[:, 5],
-                            "illumination": labels_tensor[:, 6],
-                            "occlusion": labels_tensor[:, 7],
-                            "pose": labels_tensor[:, 8],
-                            "invalid": labels_tensor[:, 9],
+                            "bbox": labels_tensor[:, 0:4].clone(),  # x, y, width, height
+                            "blur": labels_tensor[:, 4].clone(),
+                            "expression": labels_tensor[:, 5].clone(),
+                            "illumination": labels_tensor[:, 6].clone(),
+                            "occlusion": labels_tensor[:, 7].clone(),
+                            "pose": labels_tensor[:, 8].clone(),
+                            "invalid": labels_tensor[:, 9].clone(),
                         },
                     }
                 )
...
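The `.clone()` calls above turn view-returning slices into independent copies: a slice shares storage with the full labels tensor, so handing it out keeps the whole parent buffer alive. A minimal sketch of the difference (toy shapes, not WIDERFace data):

import torch

labels_tensor = torch.zeros(1000, 10)
view = labels_tensor[:, 0:4]           # a view: shares storage with labels_tensor
copy = labels_tensor[:, 0:4].clone()   # an independent copy with its own storage

print(view.data_ptr() == labels_tensor.data_ptr())   # True  - same underlying buffer
print(copy.data_ptr() == labels_tensor.data_ptr())   # False - freshly allocated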
torchvision/extension.py

-import ctypes
 import os
 import sys
-from warnings import warn

 import torch
...
@@ -22,7 +20,7 @@ try:
     # conda environment/bin path is configured Please take a look:
     # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python
     # Please note: if some path can't be added using add_dll_directory we simply ignore this path
-    if os.name == "nt" and sys.version_info >= (3, 8) and sys.version_info < (3, 9):
+    if os.name == "nt" and sys.version_info < (3, 9):
         env_path = os.environ["PATH"]
         path_arr = env_path.split(";")
         for path in path_arr:
...
@@ -76,9 +74,9 @@ def _check_cuda_version():
         t_version = torch_version_cuda.split(".")
         t_major = int(t_version[0])
         t_minor = int(t_version[1])
-        if t_major != tv_major or t_minor != tv_minor:
+        if t_major != tv_major:
             raise RuntimeError(
-                "Detected that PyTorch and torchvision were compiled with different CUDA versions. "
+                "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
                 f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has "
                 f"CUDA Version={tv_major}.{tv_minor}. "
                 "Please reinstall the torchvision that matches your PyTorch install."
...
@@ -88,19 +86,6 @@ def _check_cuda_version():
 def _load_library(lib_name):
     lib_path = _get_extension_path(lib_name)
-    # On Windows Python-3.8+ has `os.add_dll_directory` call,
-    # which is called from _get_extension_path to configure dll search path
-    # Condition below adds a workaround for older versions by
-    # explicitly calling `LoadLibraryExW` with the following flags:
-    #  - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000)
-    #  - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100)
-    if os.name == "nt" and sys.version_info < (3, 8):
-        _kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True)
-        if hasattr(_kernel32, "LoadLibraryExW"):
-            _kernel32.LoadLibraryExW(lib_path, None, 0x00001100)
-        else:
-            warn("LoadLibraryExW is missing in kernel32.dll")
     torch.ops.load_library(lib_path)
...
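The `_check_cuda_version` hunk above relaxes the compatibility check so that only the CUDA major version of PyTorch and torchvision has to agree. A minimal sketch of the new behaviour, with made-up version strings for illustration:

def check_cuda_major_only(torch_cuda: str, tv_cuda: str) -> None:
    # Mirrors the relaxed condition: compare only the major components.
    t_major, t_minor = (int(x) for x in torch_cuda.split(".")[:2])
    tv_major, tv_minor = (int(x) for x in tv_cuda.split(".")[:2])
    if t_major != tv_major:
        raise RuntimeError(
            "Detected that PyTorch and torchvision were compiled with different CUDA major versions. "
            f"PyTorch has CUDA Version={t_major}.{t_minor} and torchvision has CUDA Version={tv_major}.{tv_minor}."
        )

check_cuda_major_only("11.8", "11.7")       # minor mismatch: now accepted
try:
    check_cuda_major_only("12.1", "11.8")   # major mismatch: still raises
except RuntimeError as e:
    print(e)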
torchvision/io/__init__.py
...
@@ -8,6 +8,7 @@ try:
     from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
 except ModuleNotFoundError:
     _HAS_GPU_VIDEO_DECODER = False
 from ._video_opt import (
     _HAS_VIDEO_OPT,
     _probe_video_from_file,
...
torchvision/io/_video_opt.py
...
@@ -137,8 +137,7 @@ def _read_video_from_file(
     audio_timebase: Fraction = default_timebase,
 ) -> Tuple[torch.Tensor, torch.Tensor, VideoMetaData]:
     """
-    Reads a video from a file, returning both the video frames as well as
-    the audio frames
+    Reads a video from a file, returning both the video frames and the audio frames

     Args:
         filename (str): path to the video file
...
@@ -281,8 +280,7 @@ def _read_video_from_memory(
     audio_timebase_denominator: int = 1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Reads a video from memory, returning both the video frames as well as
-    the audio frames
+    Reads a video from memory, returning both the video frames and the audio frames
     This function is torchscriptable.

     Args:
...
@@ -336,7 +334,10 @@ def _read_video_from_memory(
     _validate_pts(audio_pts_range)

     if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)

     result = torch.ops.video_reader.read_video_from_memory(
         video_data,
...
@@ -378,7 +379,10 @@ def _read_video_timestamps_from_memory(
     is much faster than read_video(...)
     """
     if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
     result = torch.ops.video_reader.read_video_from_memory(
         video_data,
         0,  # seek_frame_margin
...
@@ -416,7 +420,10 @@ def _probe_video_from_memory(
     This function is torchscriptable
     """
     if not isinstance(video_data, torch.Tensor):
-        video_data = torch.frombuffer(video_data, dtype=torch.uint8)
+        with warnings.catch_warnings():
+            # Ignore the warning because we actually don't modify the buffer in this function
+            warnings.filterwarnings("ignore", message="The given buffer is not writable")
+            video_data = torch.frombuffer(video_data, dtype=torch.uint8)
     result = torch.ops.video_reader.probe_video_from_memory(video_data)
     vtimebase, vfps, vduration, atimebase, asample_rate, aduration = result
     info = _fill_info(vtimebase, vfps, vduration, atimebase, asample_rate, aduration)
...
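All three helpers above now wrap `torch.frombuffer` in a `warnings.catch_warnings()` block, because frombuffer warns when handed a read-only buffer such as `bytes` even though these functions only ever read the resulting tensor. A standalone sketch of the pattern:

import warnings

import torch

raw = b"\x00\x01\x02\x03"  # a read-only buffer, e.g. an in-memory video file

with warnings.catch_warnings():
    # Safe to silence here: the tensor is only read, never modified in place.
    warnings.filterwarnings("ignore", message="The given buffer is not writable")
    video_data = torch.frombuffer(raw, dtype=torch.uint8)

print(video_data)  # tensor([0, 1, 2, 3], dtype=torch.uint8)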
torchvision/io/image.py
...
@@ -10,7 +10,12 @@ from ..utils import _log_api_usage_once
 try:
     _load_library("image")
 except (ImportError, OSError) as e:
-    warn(f"Failed to load image Python extension: {e}")
+    warn(
+        f"Failed to load image Python extension: '{e}'"
+        f"If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. "
+        f"Otherwise, there might be something wrong with your environment. "
+        f"Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?"
+    )


 class ImageReadMode(Enum):
...
@@ -50,7 +55,7 @@ def read_file(path: str) -> torch.Tensor:
 def write_file(filename: str, data: torch.Tensor) -> None:
     """
-    Writes the contents of a uint8 tensor with one dimension to a
+    Writes the contents of an uint8 tensor with one dimension to a
     file.

     Args:
...
torchvision/io/video.py
...
@@ -12,7 +12,6 @@ import torch
 from ..utils import _log_api_usage_once
 from . import _video_opt

 try:
     import av
...
@@ -242,8 +241,7 @@ def read_video(
     output_format: str = "THWC",
 ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
     """
-    Reads a video from a file, returning both the video frames as well as
-    the audio frames
+    Reads a video from a file, returning both the video frames and the audio frames

     Args:
         filename (str): path to the video file
...
torchvision/io/video_reader.py

-from typing import Any, Dict, Iterator
+import io
+import warnings
+from typing import Any, Dict, Iterator, Optional

 import torch

 from ..utils import _log_api_usage_once

+try:
+    from ._load_gpu_decoder import _HAS_GPU_VIDEO_DECODER
+except ModuleNotFoundError:
+    _HAS_GPU_VIDEO_DECODER = False
 from ._video_opt import _HAS_VIDEO_OPT

 if _HAS_VIDEO_OPT:
...
@@ -21,11 +20,37 @@ else:
         return False

+try:
+    import av
+
+    av.logging.set_level(av.logging.ERROR)
+    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
+        av = ImportError(
+            """\
+Your version of PyAV is too old for the necessary video operations in torchvision.
+If you are on Python 3.5, you will have to build from source (the conda-forge
+packages are not up-to-date). See
+https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+        )
+except ImportError:
+    av = ImportError(
+        """\
+PyAV is not installed, and is necessary for the video operations in torchvision.
+See https://github.com/mikeboers/PyAV#installation for instructions on how to
+install PyAV on your system.
+"""
+    )
+

 class VideoReader:
     """
     Fine-grained video-reading API.
     Supports frame-by-frame reading of various streams from a single video
-    container.
+    container. Much like previous video_reader API it supports the following
+    backends: video_reader, pyav, and cuda.
+    Backends can be set via `torchvision.set_video_backend` function.

     .. betastatus:: VideoReader class
...
@@ -66,13 +91,18 @@ class VideoReader:
     Each stream descriptor consists of two parts: stream type (e.g. 'video') and
     a unique stream id (which are determined by the video encoding).
-    In this way, if the video contaner contains multiple
-    streams of the same type, users can acces the one they want.
+    In this way, if the video container contains multiple
+    streams of the same type, users can access the one they want.
     If only stream type is passed, the decoder auto-detects first stream of that type.

     Args:
-        path (string): Path to the video file in supported format
+        src (string, bytes object, or tensor): The media source.
+            If string-type, it must be a file path supported by FFMPEG.
+            If bytes, should be an in-memory representation of a file supported by FFMPEG.
+            If Tensor, it is interpreted internally as byte buffer.
+            It must be one-dimensional, of type ``torch.uint8``.
         stream (string, optional): descriptor of the required stream, followed by the stream id,
             in the format ``{stream_type}:{stream_id}``. Defaults to ``"video:0"``.
...
@@ -82,30 +112,73 @@ class VideoReader:
             Default value (0) enables multithreading with codec-dependent heuristic. The performance
             will depend on the version of FFMPEG codecs supported.
-        device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
-            To use GPU decoding, pass ``device="cuda"``.
+        path (str, optional):
+            .. warning:
+                This parameter was deprecated in ``0.15`` and will be removed in ``0.17``.
+                Please use ``src`` instead.
     """

-    def __init__(self, path: str, stream: str = "video", num_threads: int = 0, device: str = "cpu") -> None:
+    def __init__(
+        self,
+        src: str = "",
+        stream: str = "video",
+        num_threads: int = 0,
+        path: Optional[str] = None,
+    ) -> None:
         _log_api_usage_once(self)
-        self.is_cuda = False
-        device = torch.device(device)
-        if device.type == "cuda":
-            if not _HAS_GPU_VIDEO_DECODER:
-                raise RuntimeError("Not compiled with GPU decoder support.")
-            self.is_cuda = True
-            self._c = torch.classes.torchvision.GPUDecoder(path, device)
-            return
-        if not _has_video_opt():
-            raise RuntimeError(
-                "Not compiled with video_reader support, "
-                + "to enable video_reader support, please install "
-                + "ffmpeg (version 4.2 is currently supported) and "
-                + "build torchvision from source."
-            )
-        self._c = torch.classes.torchvision.Video(path, stream, num_threads)
+        from .. import get_video_backend
+
+        self.backend = get_video_backend()
+        if isinstance(src, str):
+            if src == "":
+                if path is None:
+                    raise TypeError("src cannot be empty")
+                src = path
+                warnings.warn("path is deprecated and will be removed in 0.17. Please use src instead")
+        elif isinstance(src, bytes):
+            if self.backend in ["cuda"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from bytes object when using cuda or pyav backend."
+                )
+            elif self.backend == "pyav":
+                src = io.BytesIO(src)
+            else:
+                with warnings.catch_warnings():
+                    # Ignore the warning because we actually don't modify the buffer in this function
+                    warnings.filterwarnings("ignore", message="The given buffer is not writable")
+                    src = torch.frombuffer(src, dtype=torch.uint8)
+        elif isinstance(src, torch.Tensor):
+            if self.backend in ["cuda", "pyav"]:
+                raise RuntimeError(
+                    "VideoReader cannot be initialized from Tensor object when using cuda or pyav backend."
+                )
+        else:
+            raise TypeError("`src` must be either string, Tensor or bytes object.")
+
+        if self.backend == "cuda":
+            device = torch.device("cuda")
+            self._c = torch.classes.torchvision.GPUDecoder(src, device)
+        elif self.backend == "video_reader":
+            if isinstance(src, str):
+                self._c = torch.classes.torchvision.Video(src, stream, num_threads)
+            elif isinstance(src, torch.Tensor):
+                self._c = torch.classes.torchvision.Video("", "", 0)
+                self._c.init_from_memory(src, stream, num_threads)
+        elif self.backend == "pyav":
+            self.container = av.open(src, metadata_errors="ignore")
+            # TODO: load metadata
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            # TODO: add extradata exception
+        else:
+            raise RuntimeError("Unknown video backend: {}".format(self.backend))

     def __next__(self) -> Dict[str, Any]:
         """Decodes and returns the next frame of the current stream.
...
@@ -119,14 +192,29 @@ class VideoReader:
             and corresponding timestamp (``pts``) in seconds

         """
-        if self.is_cuda:
+        if self.backend == "cuda":
             frame = self._c.next()
             if frame.numel() == 0:
                 raise StopIteration
-            return {"data": frame}
-        frame, pts = self._c.next()
+            return {"data": frame, "pts": None}
+        elif self.backend == "video_reader":
+            frame, pts = self._c.next()
+        else:
+            try:
+                frame = next(self._c)
+                pts = float(frame.pts * frame.time_base)
+                if "video" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_rgb().to_ndarray()).permute(2, 0, 1)
+                elif "audio" in self.pyav_stream:
+                    frame = torch.tensor(frame.to_ndarray()).permute(1, 0)
+                else:
+                    frame = None
+            except av.error.EOFError:
+                raise StopIteration
         if frame.numel() == 0:
             raise StopIteration
         return {"data": frame, "pts": pts}

     def __iter__(self) -> Iterator[Dict[str, Any]]:
...
@@ -145,7 +233,18 @@ class VideoReader:
             frame with the exact timestamp if it exists or
             the first frame with timestamp larger than ``time_s``.
         """
-        self._c.seek(time_s, keyframes_only)
+        if self.backend in ["cuda", "video_reader"]:
+            self._c.seek(time_s, keyframes_only)
+        else:
+            # handle special case as pyav doesn't catch it
+            if time_s < 0:
+                time_s = 0
+            temp_str = self.container.streams.get(**self.pyav_stream)[0]
+            offset = int(round(time_s / temp_str.time_base))
+            if not keyframes_only:
+                warnings.warn("Accurate seek is not implemented for pyav backend")
+            self.container.seek(offset, backward=True, any_frame=False, stream=temp_str)
+            self._c = self.container.decode(**self.pyav_stream)
         return self

     def get_metadata(self) -> Dict[str, Any]:
...
@@ -154,6 +253,21 @@ class VideoReader:
         Returns:
             (dict): dictionary containing duration and frame rate for every stream
         """
+        if self.backend == "pyav":
+            metadata = {}  # type: Dict[str, Any]
+            for stream in self.container.streams:
+                if stream.type not in metadata:
+                    if stream.type == "video":
+                        rate_n = "fps"
+                    else:
+                        rate_n = "framerate"
+                    metadata[stream.type] = {rate_n: [], "duration": []}
+                rate = stream.average_rate if stream.average_rate is not None else stream.sample_rate
+                metadata[stream.type]["duration"].append(float(stream.duration * stream.time_base))
+                metadata[stream.type][rate_n].append(float(rate))
+            return metadata
         return self._c.get_metadata()

     def set_current_stream(self, stream: str) -> bool:
...
@@ -165,14 +279,20 @@ class VideoReader:
         Currently available stream types include ``['video', 'audio']``.
         Each descriptor consists of two parts: stream type (e.g. 'video') and
         a unique stream id (which are determined by video encoding).
-        In this way, if the video contaner contains multiple
-        streams of the same type, users can acces the one they want.
+        In this way, if the video container contains multiple
+        streams of the same type, users can access the one they want.
         If only stream type is passed, the decoder auto-detects first stream
         of that type and returns it.

         Returns:
-            (bool): True on succes, False otherwise
+            (bool): True on success, False otherwise
         """
-        if self.is_cuda:
-            print("GPU decoding only works with video stream.")
+        if self.backend == "cuda":
+            warnings.warn("GPU decoding only works with video stream.")
+        if self.backend == "pyav":
+            stream_type = stream.split(":")[0]
+            stream_id = 0 if len(stream.split(":")) == 1 else int(stream.split(":")[1])
+            self.pyav_stream = {stream_type: stream_id}
+            self._c = self.container.decode(**self.pyav_stream)
+            return True
         return self._c.set_current_stream(stream)
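A hedged usage sketch of the reworked constructor above; it assumes a local file video.mp4 and the pyav backend, and is an illustration rather than part of the commit. The first positional argument is now src and may be a path, raw bytes, or a one-dimensional torch.uint8 tensor; path= still works but warns until 0.17:

import torchvision
from torchvision.io import VideoReader

torchvision.set_video_backend("pyav")  # "video_reader" or "cuda" take the other branches above

# From a file path (previously the `path` argument):
reader = VideoReader("video.mp4", "video")
print(reader.get_metadata())  # the pyav backend now assembles this dict itself

# From an in-memory bytes object; pyav wraps it in io.BytesIO,
# the video_reader backend would convert it with torch.frombuffer instead.
with open("video.mp4", "rb") as f:
    reader_mem = VideoReader(f.read(), "video")

frame = next(reader_mem)
print(frame["data"].shape, frame["pts"])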
torchvision/models/__init__.py
...
@@ -15,4 +15,9 @@ from .vision_transformer import *
 from .swin_transformer import *
 from .maxvit import *
 from . import detection, optical_flow, quantization, segmentation, video
-from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models
+
+# The Weights and WeightsEnum are developer-facing utils that we make public for
+# downstream libs like torchgeo https://github.com/pytorch/vision/issues/7094
+# TODO: we could / should document them publicly, but it's not clear where, as
+# they're not intended for end users.
+from ._api import get_model, get_model_builder, get_model_weights, get_weight, list_models, Weights, WeightsEnum
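With Weights and WeightsEnum re-exported here, a downstream library can annotate against the public torchvision.models namespace instead of the private _api module. A small hypothetical sketch (the build_backbone helper is illustrative, not torchvision API):

from typing import Optional

from torchvision.models import WeightsEnum, get_model

def build_backbone(name: str, weights: Optional[WeightsEnum] = None):
    # `weights` can be any torchvision weight enum member, e.g. ResNet50_Weights.DEFAULT.
    return get_model(name, weights=weights)

model = build_backbone("resnet18")  # random init; pass a WeightsEnum member for pretrained weights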
torchvision/models/_api.py

+import fnmatch
 import importlib
 import inspect
 import sys
-from dataclasses import dataclass, fields
+from dataclasses import dataclass
+from enum import Enum
+from functools import partial
 from inspect import signature
 from types import ModuleType
-from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union

 from torch import nn

-from torchvision._utils import StrEnum
-
 from .._internally_replaced_utils import load_state_dict_from_url
...
@@ -37,8 +38,34 @@ class Weights:
     transforms: Callable
     meta: Dict[str, Any]

+    def __eq__(self, other: Any) -> bool:
+        # We need this custom implementation for correct deep-copy and deserialization behavior.
+        # TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
+        # involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
+        # defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
+        # for it, the check against the defined members would fail and effectively prevent the weights from being
+        # deep-copied or deserialized.
+        # See https://github.com/pytorch/vision/pull/7107 for details.
+        if not isinstance(other, Weights):
+            return NotImplemented
+
+        if self.url != other.url:
+            return False
+
+        if self.meta != other.meta:
+            return False
+
+        if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
+            return (
+                self.transforms.func == other.transforms.func
+                and self.transforms.args == other.transforms.args
+                and self.transforms.keywords == other.transforms.keywords
+            )
+        else:
+            return self.transforms == other.transforms
+

-class WeightsEnum(StrEnum):
+class WeightsEnum(Enum):
     """
     This class is the parent class of all model weights. Each model building method receives an optional `weights`
     parameter with its associated pre-trained weights. It inherits from `Enum` and its values should be of type
...
@@ -48,40 +75,40 @@ class WeightsEnum(StrEnum):
         value (Weights): The data class entry with the weight information.
     """

-    def __init__(self, value: Weights):
-        self._value_ = value
-
     @classmethod
     def verify(cls, obj: Any) -> Any:
         if obj is not None:
             if type(obj) is str:
-                obj = cls.from_str(obj.replace(cls.__name__ + ".", ""))
+                obj = cls[obj.replace(cls.__name__ + ".", "")]
             elif not isinstance(obj, cls):
                 raise TypeError(
                     f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
                 )
         return obj

-    def get_state_dict(self, progress: bool) -> Mapping[str, Any]:
-        return load_state_dict_from_url(self.url, progress=progress)
+    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
+        return load_state_dict_from_url(self.url, *args, **kwargs)

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}.{self._name_}"

-    def __getattr__(self, name):
-        # Be able to fetch Weights attributes directly
-        for f in fields(Weights):
-            if f.name == name:
-                return object.__getattribute__(self.value, name)
-        return super().__getattr__(name)
+    @property
+    def url(self):
+        return self.value.url
+
+    @property
+    def transforms(self):
+        return self.value.transforms
+
+    @property
+    def meta(self):
+        return self.value.meta


 def get_weight(name: str) -> WeightsEnum:
     """
     Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

-    .. betastatus:: function
-
     Args:
         name (str): The name of the weight enum entry.
...
@@ -96,7 +123,9 @@ def get_weight(name: str) -> WeightsEnum:
     base_module_name = ".".join(sys.modules[__name__].__name__.split(".")[:-1])
     base_module = importlib.import_module(base_module_name)
     model_modules = [base_module] + [
-        x[1] for x in inspect.getmembers(base_module, inspect.ismodule) if x[1].__file__.endswith("__init__.py")
+        x[1]
+        for x in inspect.getmembers(base_module, inspect.ismodule)
+        if x[1].__file__.endswith("__init__.py")  # type: ignore[union-attr]
     ]

     weights_enum = None
...
@@ -109,14 +138,12 @@ def get_weight(name: str) -> WeightsEnum:
     if weights_enum is None:
         raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")

-    return weights_enum.from_str(value_name)
+    return weights_enum[value_name]


-def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
+def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
     """
-    Retuns the weights enum class associated to the given model.
-
-    .. betastatus:: function
+    Returns the weights enum class associated to the given model.

     Args:
         name (callable or str): The model builder function or the name under which it is registered.
...
@@ -128,13 +155,12 @@ def get_model_weights(name: Union[Callable, str]) -> WeightsEnum:
     return _get_enum_from_fn(model)


-def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
+def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
     """
     Internal method that gets the weight enum of a specific model builder method.

     Args:
         fn (Callable): The builder method used to create the model.
-        weight_name (str): The name of the weight enum entry of the specific model.

     Returns:
         WeightsEnum: The requested weight enum.
     """
...
@@ -159,7 +185,7 @@ def _get_enum_from_fn(fn: Callable) -> WeightsEnum:
             "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
         )

-    return cast(WeightsEnum, weights_enum)
+    return weights_enum


 M = TypeVar("M", bound=nn.Module)
...
@@ -178,21 +204,43 @@ def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], C
     return wrapper


-def list_models(module: Optional[ModuleType] = None) -> List[str]:
+def list_models(
+    module: Optional[ModuleType] = None,
+    include: Union[Iterable[str], str, None] = None,
+    exclude: Union[Iterable[str], str, None] = None,
+) -> List[str]:
     """
     Returns a list with the names of registered models.

-    .. betastatus:: function
-
     Args:
         module (ModuleType, optional): The module from which we want to extract the available models.
+        include (str or Iterable[str], optional): Filter(s) for including the models from the set of all models.
+            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is the union of individual filters.
+        exclude (str or Iterable[str], optional): Filter(s) applied after include_filters to remove models.
+            Filter are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix shell-style
+            wildcards. In case of many filters, the results is removal of all the models that match any individual filter.

     Returns:
         models (list): A list with the names of available models.
     """
-    models = [
-        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
-    ]
+    all_models = {
+        k for k, v in BUILTIN_MODELS.items() if module is None or v.__module__.rsplit(".", 1)[0] == module.__name__
+    }
+    if include:
+        models: Set[str] = set()
+        if isinstance(include, str):
+            include = [include]
+        for include_filter in include:
+            models = models | set(fnmatch.filter(all_models, include_filter))
+    else:
+        models = all_models
+
+    if exclude:
+        if isinstance(exclude, str):
+            exclude = [exclude]
+        for exclude_filter in exclude:
+            models = models - set(fnmatch.filter(all_models, exclude_filter))
     return sorted(models)
...
@@ -200,8 +248,6 @@ def get_model_builder(name: str) -> Callable[..., nn.Module]:
     """
     Gets the model name and returns the model builder method.

-    .. betastatus:: function
-
     Args:
         name (str): The name under which the model is registered.
...
@@ -220,8 +266,6 @@ def get_model(name: str, **config: Any) -> nn.Module:
     """
     Gets the model name and configuration and returns an instantiated model.

-    .. betastatus:: function
-
     Args:
         name (str): The name under which the model is registered.
         **config (Any): parameters passed to the model builder method.
...
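A hedged sketch of how the reworked registration API behaves after this change (torchvision >= 0.16 assumed; the concrete model names are only examples):

from torchvision import models

# list_models() now accepts fnmatch-style include/exclude filters.
all_detection = models.list_models(module=models.detection)
small_resnets = models.list_models(include="resnet*", exclude="*101*")

# Weight enums are plain Enum subclasses now, so members are looked up by
# name with [] indexing (replacing the old from_str helper) ...
weights = models.get_model_weights("resnet50")["IMAGENET1K_V2"]

# ... and the Weights fields remain reachable directly via the new properties.
print(weights.url)
print(len(weights.meta["categories"]))

# get_state_dict forwards extra arguments to load_state_dict_from_url,
# which is how the model builders can now pass check_hash=True.
state_dict = weights.get_state_dict(progress=True, check_hash=True)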
torchvision/models/_utils.py
...
@@ -191,7 +191,7 @@ def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[D
             # used to be a pretrained parameter.
             pretrained_positional = weights_arg is not sentinel
             if pretrained_positional:
-                # We put the pretrained argument under its legacy name in the keyword argument dictionary to have a
+                # We put the pretrained argument under its legacy name in the keyword argument dictionary to have
                 # unified access to the value if the default value is a callable.
                 kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
             else:
...
torchvision/models/alexnet.py
...
@@ -67,6 +67,8 @@ class AlexNet_Weights(WeightsEnum):
                     "acc@5": 79.066,
                 }
             },
+            "_ops": 0.714,
+            "_file_size": 233.087,
             "_docs": """
                 These weights reproduce closely the results of the paper using a simplified training recipe.
             """,
...
@@ -112,17 +114,6 @@ def alexnet(*, weights: Optional[AlexNet_Weights] = None, progress: bool = True,
     model = AlexNet(**kwargs)

     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "alexnet": AlexNet_Weights.IMAGENET1K_V1.url,
-    }
-)
torchvision/models/convnext.py
...
@@ -189,7 +189,7 @@ def _convnext(
     model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs)

     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

     return model
...
@@ -219,6 +219,8 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
                     "acc@5": 96.146,
                 }
             },
+            "_ops": 4.456,
+            "_file_size": 109.119,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -237,6 +239,8 @@ class ConvNeXt_Small_Weights(WeightsEnum):
                     "acc@5": 96.650,
                 }
             },
+            "_ops": 8.684,
+            "_file_size": 191.703,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -255,6 +259,8 @@ class ConvNeXt_Base_Weights(WeightsEnum):
                     "acc@5": 96.870,
                 }
             },
+            "_ops": 15.355,
+            "_file_size": 338.064,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -273,6 +279,8 @@ class ConvNeXt_Large_Weights(WeightsEnum):
                     "acc@5": 96.976,
                 }
             },
+            "_ops": 34.361,
+            "_file_size": 754.537,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
torchvision/models/densenet.py
...
@@ -15,7 +15,6 @@ from ._api import register_model, Weights, WeightsEnum
 from ._meta import _IMAGENET_CATEGORIES
 from ._utils import _ovewrite_named_param, handle_legacy_interface

 __all__ = [
     "DenseNet",
     "DenseNet121_Weights",
...
@@ -228,7 +227,7 @@ def _load_state_dict(model: nn.Module, weights: WeightsEnum, progress: bool) ->
         r"^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$"
     )

-    state_dict = weights.get_state_dict(progress=progress)
+    state_dict = weights.get_state_dict(progress=progress, check_hash=True)
     for key in list(state_dict.keys()):
         res = pattern.match(key)
         if res:
...
@@ -278,6 +277,8 @@ class DenseNet121_Weights(WeightsEnum):
                     "acc@5": 91.972,
                 }
             },
+            "_ops": 2.834,
+            "_file_size": 30.845,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -296,6 +297,8 @@ class DenseNet161_Weights(WeightsEnum):
                     "acc@5": 93.560,
                 }
             },
+            "_ops": 7.728,
+            "_file_size": 110.369,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -314,6 +317,8 @@ class DenseNet169_Weights(WeightsEnum):
                     "acc@5": 92.806,
                 }
             },
+            "_ops": 3.36,
+            "_file_size": 54.708,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -332,6 +337,8 @@ class DenseNet201_Weights(WeightsEnum):
                     "acc@5": 93.370,
                 }
             },
+            "_ops": 4.291,
+            "_file_size": 77.373,
         },
     )
     DEFAULT = IMAGENET1K_V1
...
@@ -439,17 +446,3 @@ def densenet201(*, weights: Optional[DenseNet201_Weights] = None, progress: bool
     weights = DenseNet201_Weights.verify(weights)

     return _densenet(32, (6, 12, 48, 32), 64, weights, progress, **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "densenet121": DenseNet121_Weights.IMAGENET1K_V1.url,
-        "densenet169": DenseNet169_Weights.IMAGENET1K_V1.url,
-        "densenet201": DenseNet201_Weights.IMAGENET1K_V1.url,
-        "densenet161": DenseNet161_Weights.IMAGENET1K_V1.url,
-    }
-)
torchvision/models/detection/_utils.py
...
@@ -25,7 +25,7 @@ class BalancedPositiveNegativeSampler:
     def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
         """
         Args:
-            matched idxs: list of tensors containing -1, 0 or positive values.
+            matched_idxs: list of tensors containing -1, 0 or positive values.
                 Each tensor corresponds to a specific image.
                 -1 values are ignored, 0 are considered as negatives and > 0 as
                 positives.
...
@@ -403,22 +403,14 @@ class Matcher:
        it is unmatched, then match it to the ground-truth with which it has the highest
        quality value.
        """
-        # For each gt, find the prediction with which it has highest quality
+        # For each gt, find the prediction with which it has the highest quality
        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
-        # Find highest quality match available, even if it is low, including ties
+        # Find the highest quality match available, even if it is low, including ties
        gt_pred_pairs_of_highest_quality = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
        # Example gt_pred_pairs_of_highest_quality:
-        # tensor([[    0, 39796],
-        #         [    1, 32055],
-        #         [    1, 32070],
-        #         [    2, 39190],
-        #         [    2, 40255],
-        #         [    3, 40390],
-        #         [    3, 41455],
-        #         [    4, 45470],
-        #         [    5, 45325],
-        #         [    5, 46390]])
-        # Each row is a (gt index, prediction index)
+        # (tensor([0, 1, 1, 2, 2, 3, 3, 4, 5, 5]),
+        #  tensor([39796, 32055, 32070, 39190, 40255, 40390, 41455, 45470, 45325, 46390]))
+        # Each element in the first tensor is a gt index, and each element in second tensor is a prediction index
        # Note how gt items 1, 2, 3, and 5 each have two ties
        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
...
@@ -501,14 +493,14 @@ def _topk_min(input: Tensor, orig_kval: int, axis: int) -> int:
     if K exceeds the number of elements along that axis. Previously, python's min() function was
     used to determine whether to use the provided k-value or the specified dim axis value.
-    However in cases where the model is being exported in tracing mode, python min() is
+    However, in cases where the model is being exported in tracing mode, python min() is
     static causing the model to be traced incorrectly and eventually fail at the topk node.
     In order to avoid this situation, in tracing mode, torch.min() is used instead.

     Args:
-        input (Tensor): The orignal input tensor.
+        input (Tensor): The original input tensor.
         orig_kval (int): The provided k-value.
-        axis(int): Axis along which we retreive the input size.
+        axis(int): Axis along which we retrieve the input size.

     Returns:
         min_kval (int): Appropriately selected k-value.
...
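The rewritten comment above reflects what torch.where actually returns on a boolean matrix: a tuple of index tensors (gt indices, prediction indices) rather than a single two-column tensor of pairs. A toy reproduction with small numbers:

import torch

# 2 ground-truth boxes x 3 predictions, with ties in both rows.
match_quality_matrix = torch.tensor([[0.1, 0.9, 0.9],
                                     [0.8, 0.2, 0.8]])
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
gt_pred_pairs = torch.where(match_quality_matrix == highest_quality_foreach_gt[:, None])
print(gt_pred_pairs)
# (tensor([0, 0, 1, 1]), tensor([1, 2, 0, 2])): the first tensor holds gt indices,
# the second the matching prediction indices, ties included.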
torchvision/models/detection/anchor_utils.py
...
@@ -61,7 +61,7 @@ class AnchorGenerator(nn.Module):
         aspect_ratios: List[float],
         dtype: torch.dtype = torch.float32,
         device: torch.device = torch.device("cpu"),
-    ):
+    ) -> Tensor:
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
...
@@ -76,7 +76,7 @@ class AnchorGenerator(nn.Module):
     def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
         self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]

-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
         return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
...
@@ -145,7 +145,7 @@ class DefaultBoxGenerator(nn.Module):
             of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
         scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using
             the ``min_ratio`` and ``max_ratio`` parameters.
-        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of defalt boxes. If not provided
+        steps (List[int]], optional): It's a hyper-parameter that affects the tiling of default boxes. If not provided
             it will be estimated from the data.
         clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
             is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
...
@@ -201,7 +201,7 @@ class DefaultBoxGenerator(nn.Module):
             _wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
         return _wh_pairs

-    def num_anchors_per_location(self):
+    def num_anchors_per_location(self) -> List[int]:
         # Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
         return [2 + 2 * len(r) for r in self.aspect_ratios]
...
torchvision/models/detection/backbone_utils.py
View file @
cc26cd81
...
@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
...
@@ -62,7 +62,7 @@ class BackboneWithFPN(nn.Module):
@
handle_legacy_interface
(
@
handle_legacy_interface
(
weights
=
(
weights
=
(
"pretrained"
,
"pretrained"
,
lambda
kwargs
:
_get_enum_from_fn
(
resnet
.
__dict__
[
kwargs
[
"backbone_name"
]])
.
from_str
(
"IMAGENET1K_V1"
)
,
lambda
kwargs
:
_get_enum_from_fn
(
resnet
.
__dict__
[
kwargs
[
"backbone_name"
]])
[
"IMAGENET1K_V1"
]
,
),
),
)
)
def
resnet_fpn_backbone
(
def
resnet_fpn_backbone
(
...
@@ -102,12 +102,12 @@ def resnet_fpn_backbone(
...
@@ -102,12 +102,12 @@ def resnet_fpn_backbone(
trainable_layers (int): number of trainable (not frozen) layers starting from final block.
trainable_layers (int): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
returned_layers (list of int): The layers of the network to return. Each entry must be in ``[1, 4]``.
By default all layers are returned.
By default
,
all layers are returned.
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
extra_blocks (ExtraFPNBlock or None): if provided, extra operations will
be performed. It is expected to take the fpn features, the original
be performed. It is expected to take the fpn features, the original
features and the names of the original features as input, and returns
features and the names of the original features as input, and returns
a new list of feature maps and their corresponding names. By
a new list of feature maps and their corresponding names. By
default a ``LastLevelMaxPool`` is used.
default
,
a ``LastLevelMaxPool`` is used.
"""
"""
backbone
=
resnet
.
__dict__
[
backbone_name
](
weights
=
weights
,
norm_layer
=
norm_layer
)
backbone
=
resnet
.
__dict__
[
backbone_name
](
weights
=
weights
,
norm_layer
=
norm_layer
)
return
_resnet_fpn_extractor
(
backbone
,
trainable_layers
,
returned_layers
,
extra_blocks
)
return
_resnet_fpn_extractor
(
backbone
,
trainable_layers
,
returned_layers
,
extra_blocks
)
...
@@ -121,7 +121,7 @@ def _resnet_fpn_extractor(
...
@@ -121,7 +121,7 @@ def _resnet_fpn_extractor(
norm_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
norm_layer
:
Optional
[
Callable
[...,
nn
.
Module
]]
=
None
,
)
->
BackboneWithFPN
:
)
->
BackboneWithFPN
:
# select layers that wont be frozen
# select layers that won
'
t be frozen
if
trainable_layers
<
0
or
trainable_layers
>
5
:
if
trainable_layers
<
0
or
trainable_layers
>
5
:
raise
ValueError
(
f
"Trainable layers should be in the range [0,5], got
{
trainable_layers
}
"
)
raise
ValueError
(
f
"Trainable layers should be in the range [0,5], got
{
trainable_layers
}
"
)
layers_to_train
=
[
"layer4"
,
"layer3"
,
"layer2"
,
"layer1"
,
"conv1"
][:
trainable_layers
]
layers_to_train
=
[
"layer4"
,
"layer3"
,
"layer2"
,
"layer1"
,
"conv1"
][:
trainable_layers
]
...
@@ -158,7 +158,7 @@ def _validate_trainable_layers(
...
@@ -158,7 +158,7 @@ def _validate_trainable_layers(
if
not
is_trained
:
if
not
is_trained
:
if
trainable_backbone_layers
is
not
None
:
if
trainable_backbone_layers
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
"Changing trainable_backbone_layers has no
t
effect if "
"Changing trainable_backbone_layers has no effect if "
"neither pretrained nor pretrained_backbone have been set to True, "
"neither pretrained nor pretrained_backbone have been set to True, "
f
"falling back to trainable_backbone_layers=
{
max_value
}
so that all layers are trainable"
f
"falling back to trainable_backbone_layers=
{
max_value
}
so that all layers are trainable"
)
)
...
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
...
@@ -177,7 +177,7 @@ def _validate_trainable_layers(
@
handle_legacy_interface
(
@
handle_legacy_interface
(
weights
=
(
weights
=
(
"pretrained"
,
"pretrained"
,
lambda
kwargs
:
_get_enum_from_fn
(
mobilenet
.
__dict__
[
kwargs
[
"backbone_name"
]])
.
from_str
(
"IMAGENET1K_V1"
)
,
lambda
kwargs
:
_get_enum_from_fn
(
mobilenet
.
__dict__
[
kwargs
[
"backbone_name"
]])
[
"IMAGENET1K_V1"
]
,
),
),
)
)
def
mobilenet_backbone
(
def
mobilenet_backbone
(
...
@@ -208,7 +208,7 @@ def _mobilenet_extractor(
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
     num_stages = len(stage_indices)
-    # find the index of the layer from which we wont freeze
+    # find the index of the layer from which we won't freeze
     if trainable_layers < 0 or trainable_layers > num_stages:
         raise ValueError(f"Trainable layers should be in the range [0,{num_stages}], got {trainable_layers}")
     freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
...
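The `freeze_before` arithmetic is easiest to see with toy numbers (the indices below are made up, not taken from a real backbone):

```python
stage_indices = [0, 2, 4, 7, 13, 16]   # hypothetical indices of the stage boundaries
num_stages = len(stage_indices)        # 6
backbone_len = 17                      # hypothetical len(backbone)
trainable_layers = 2

freeze_before = backbone_len if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
print(freeze_before)  # 13 -> parameters of modules [0, 13) get requires_grad_(False)
```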
torchvision/models/detection/faster_rcnn.py
View file @
cc26cd81
...
@@ -47,9 +47,9 @@ class FasterRCNN(GeneralizedRCNN):
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
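A minimal sketch of that training-time contract (random data and untrained weights so nothing is downloaded):

```python
import torch
import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    weights=None, weights_backbone=None, num_classes=2
)
model.train()

images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
targets = [
    {"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]), "labels": torch.tensor([1])},
    {"boxes": torch.tensor([[30.0, 40.0, 120.0, 220.0]]), "labels": torch.tensor([1])},
]
loss_dict = model(images, targets)   # training mode returns losses, not detections
print(sorted(loss_dict.keys()))
```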
...
@@ -68,7 +68,7 @@ class FasterRCNN(GeneralizedRCNN):
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
-            It should contain a out_channels attribute, which indicates the number of output
+            It should contain an out_channels attribute, which indicates the number of output
             channels that each feature map has (and it should be the same for all feature maps).
             The backbone should return a single Tensor or and OrderedDict[Tensor].
         num_classes (int): number of output classes of the model (including the background).
...
@@ -128,7 +128,7 @@ class FasterRCNN(GeneralizedRCNN):
     >>> # only the features
     >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
     >>> # FasterRCNN needs to know the number of
-    >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+    >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
     >>> # so we need to add it here
     >>> backbone.out_channels = 1280
     >>>
...
@@ -388,6 +388,8 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
                     "box_map": 37.0,
                 }
             },
+            "_ops": 134.38,
+            "_file_size": 159.743,
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
     )
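The new (underscore-prefixed, informal) meta entries are readable straight off the enum member:

```python
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

meta = FasterRCNN_ResNet50_FPN_Weights.COCO_V1.meta
print(meta["_ops"], meta["_file_size"])  # 134.38 and 159.743, as recorded above
```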
...
@@ -407,6 +409,8 @@ class FasterRCNN_ResNet50_FPN_V2_Weights(WeightsEnum):
                     "box_map": 46.7,
                 }
             },
+            "_ops": 280.371,
+            "_file_size": 167.104,
             "_docs": """These weights were produced using an enhanced training recipe to boost the model accuracy.""",
         },
     )
...
@@ -426,6 +430,8 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
                     "box_map": 32.8,
                 }
             },
+            "_ops": 4.494,
+            "_file_size": 74.239,
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
     )
...
@@ -445,6 +451,8 @@ class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
                     "box_map": 22.8,
                 }
             },
+            "_ops": 0.719,
+            "_file_size": 74.239,
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
     )
...
@@ -475,9 +483,9 @@ def fasterrcnn_resnet50_fpn(
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and a targets (list of dictionary),
     containing:
     - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
...
@@ -563,7 +571,7 @@ def fasterrcnn_resnet50_fpn(
     model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
         if weights == FasterRCNN_ResNet50_FPN_Weights.COCO_V1:
             overwrite_eps(model, 0.0)
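For context, the keyword arguments given to `get_state_dict` are forwarded to `torch.hub.load_state_dict_from_url`, so `check_hash=True` makes the download fail loudly if the file does not match the hash fragment embedded in its filename. A hedged sketch of the equivalent direct call (the URL follows the usual hash-suffixed naming scheme on download.pytorch.org; treat it as illustrative):

```python
import torch

state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
    progress=True,
    check_hash=True,  # verify the file against the "258fb6c6" hash in its name
)
```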
...
@@ -645,7 +653,7 @@ def fasterrcnn_resnet50_fpn_v2(
     )
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
...
@@ -686,7 +694,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn(
     )
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
...
@@ -706,7 +714,7 @@ def fasterrcnn_mobilenet_v3_large_320_fpn(
     **kwargs: Any,
 ) -> FasterRCNN:
     """
-    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tunned for mobile use cases.
+    Low resolution Faster R-CNN model with a MobileNetV3-Large backbone tuned for mobile use cases.
     .. betastatus:: detection module
...
@@ -833,16 +841,3 @@ def fasterrcnn_mobilenet_v3_large_fpn(
         trainable_backbone_layers=trainable_backbone_layers,
         **kwargs,
     )
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "fasterrcnn_resnet50_fpn_coco": FasterRCNN_ResNet50_FPN_Weights.COCO_V1.url,
-        "fasterrcnn_mobilenet_v3_large_320_fpn_coco": FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1.url,
-        "fasterrcnn_mobilenet_v3_large_fpn_coco": FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1.url,
-    }
-)
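With the legacy `model_urls` dict gone, the checkpoint URL lives on the weights enum itself:

```python
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights

print(FasterRCNN_ResNet50_FPN_Weights.COCO_V1.url)
```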
torchvision/models/detection/fcos.py
View file @
cc26cd81
...
@@ -70,7 +70,7 @@ class FCOSHead(nn.Module):
             else:
                 gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)]
                 gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)]
-                gt_classes_targets[matched_idxs_per_image < 0] = -1  # backgroud
+                gt_classes_targets[matched_idxs_per_image < 0] = -1  # background
             all_gt_classes_targets.append(gt_classes_targets)
             all_gt_boxes_targets.append(gt_boxes_targets)
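The clip-then-overwrite pattern above is worth a toy example: unmatched locations carry a negative index, so clipping lets them index row 0 safely before their label is rewritten to -1 (background):

```python
import torch

labels = torch.tensor([7, 3])                  # ground-truth classes for two objects
matched_idxs = torch.tensor([-1, 0, 1, -1])    # per-location match; negative = background
gt_classes = labels[matched_idxs.clip(min=0)]  # [7, 7, 3, 7]
gt_classes[matched_idxs < 0] = -1              # [-1, 7, 3, -1]
print(gt_classes)
```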
...
@@ -274,9 +274,9 @@ class FCOS(nn.Module):
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
...
@@ -329,7 +329,7 @@ class FCOS(nn.Module):
     >>> # only the features
     >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
     >>> # FCOS needs to know the number of
-    >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+    >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
     >>> # so we need to add it here
     >>> backbone.out_channels = 1280
     >>>
...
@@ -662,6 +662,8 @@ class FCOS_ResNet50_FPN_Weights(WeightsEnum):
                     "box_map": 39.2,
                 }
             },
+            "_ops": 128.207,
+            "_file_size": 123.608,
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
     )
...
@@ -693,9 +695,9 @@ def fcos_resnet50_fpn(
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and targets (list of dictionary),
     containing:
     - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
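And the evaluation-mode side of that contract, as a short sketch (untrained weights keep the snippet download-free; the output structure is the same with pretrained weights):

```python
import torch
import torchvision

model = torchvision.models.detection.fcos_resnet50_fpn(weights=None, weights_backbone=None)
model.eval()                                    # eval mode returns detections, not losses
with torch.no_grad():
    predictions = model([torch.rand(3, 480, 640)])
print(list(predictions[0].keys()))              # per-image boxes / scores / labels
```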
...
@@ -764,17 +766,6 @@ def fcos_resnet50_fpn(
     model = FCOS(backbone, num_classes, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "fcos_resnet50_fpn_coco": FCOS_ResNet50_FPN_Weights.COCO_V1.url,
-    }
-)
torchvision/models/detection/keypoint_rcnn.py
View file @
cc26cd81
...
@@ -29,9 +29,9 @@ class KeypointRCNN(FasterRCNN):
     The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
     image, and should be in 0-1 range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and targets (list of dictionary),
     containing:
         - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
           ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``.
...
@@ -55,7 +55,7 @@ class KeypointRCNN(FasterRCNN):
     Args:
         backbone (nn.Module): the network used to compute the features for the model.
-            It should contain a out_channels attribute, which indicates the number of output
+            It should contain an out_channels attribute, which indicates the number of output
             channels that each feature map has (and it should be the same for all feature maps).
             The backbone should return a single Tensor or and OrderedDict[Tensor].
         num_classes (int): number of output classes of the model (including the background).
...
@@ -121,7 +121,7 @@ class KeypointRCNN(FasterRCNN):
     >>> # only the features
     >>> backbone = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).features
     >>> # KeypointRCNN needs to know the number of
-    >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+    >>> # output channels in a backbone. For mobilenet_v2, it's 1280,
     >>> # so we need to add it here
     >>> backbone.out_channels = 1280
     >>>
...
@@ -328,6 +328,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
                     "kp_map": 61.1,
                 }
             },
+            "_ops": 133.924,
+            "_file_size": 226.054,
             "_docs": """
                 These weights were produced by following a similar training recipe as on the paper but use a checkpoint
                 from an early epoch.
...
@@ -347,6 +349,8 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
                     "kp_map": 65.0,
                 }
             },
+            "_ops": 137.42,
+            "_file_size": 226.054,
             "_docs": """These weights were produced by following a similar training recipe as on the paper.""",
         },
     )
...
@@ -383,9 +387,9 @@ def keypointrcnn_resnet50_fpn(
     The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
     image, and should be in ``0-1`` range. Different images can have different sizes.
-    The behavior of the model changes depending if it is in training or evaluation mode.
-    During training, the model expects both the input tensors, as well as a targets (list of dictionary),
+    The behavior of the model changes depending on if it is in training or evaluation mode.
+    During training, the model expects both the input tensors and targets (list of dictionary),
     containing:
     - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
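For keypoint training specifically, each target additionally carries a ``keypoints`` tensor; a hedged sketch of one such entry (COCO-style 17 keypoints, last channel is visibility):

```python
import torch

target = {
    "boxes": torch.tensor([[20.0, 30.0, 200.0, 300.0]]),   # [x1, y1, x2, y2]
    "labels": torch.tensor([1]),
    "keypoints": torch.cat(
        [torch.rand(1, 17, 2) * 200.0, torch.ones(1, 17, 1)], dim=-1
    ),  # shape [num_instances, K, 3] = (x, y, visibility)
}
print(target["keypoints"].shape)  # torch.Size([1, 17, 3])
```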
...
@@ -461,21 +465,8 @@ def keypointrcnn_resnet50_fpn(
     model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
         if weights == KeypointRCNN_ResNet50_FPN_Weights.COCO_V1:
             overwrite_eps(model, 0.0)
     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        # legacy model for BC reasons, see https://github.com/pytorch/vision/issues/1606
-        "keypointrcnn_resnet50_fpn_coco_legacy": KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY.url,
-        "keypointrcnn_resnet50_fpn_coco": KeypointRCNN_ResNet50_FPN_Weights.COCO_V1.url,
-    }
-)