Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
cc26cd81
Commit
cc26cd81
authored
Nov 27, 2023
by
panning
Browse files
merge v0.16.0
parents
f78f29f5
fbb4cc54
Changes
370
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1098 additions
and
253 deletions
+1098
-253
torchvision/models/quantization/shufflenetv2.py
torchvision/models/quantization/shufflenetv2.py
+10
-15
torchvision/models/regnet.py
torchvision/models/regnet.py
+70
-26
torchvision/models/resnet.py
torchvision/models/resnet.py
+41
-26
torchvision/models/segmentation/deeplabv3.py
torchvision/models/segmentation/deeplabv3.py
+9
-16
torchvision/models/segmentation/fcn.py
torchvision/models/segmentation/fcn.py
+6
-14
torchvision/models/segmentation/lraspp.py
torchvision/models/segmentation/lraspp.py
+3
-12
torchvision/models/shufflenetv2.py
torchvision/models/shufflenetv2.py
+10
-16
torchvision/models/squeezenet.py
torchvision/models/squeezenet.py
+5
-13
torchvision/models/swin_transformer.py
torchvision/models/swin_transformer.py
+31
-13
torchvision/models/vgg.py
torchvision/models/vgg.py
+19
-19
torchvision/models/video/__init__.py
torchvision/models/video/__init__.py
+1
-0
torchvision/models/video/mvit.py
torchvision/models/video/mvit.py
+10
-5
torchvision/models/video/resnet.py
torchvision/models/video/resnet.py
+7
-1
torchvision/models/video/s3d.py
torchvision/models/video/s3d.py
+3
-1
torchvision/models/video/swin_transformer.py
torchvision/models/video/swin_transformer.py
+743
-0
torchvision/models/vision_transformer.py
torchvision/models/vision_transformer.py
+24
-18
torchvision/ops/_box_convert.py
torchvision/ops/_box_convert.py
+1
-1
torchvision/ops/_register_onnx_ops.py
torchvision/ops/_register_onnx_ops.py
+95
-54
torchvision/ops/ciou_loss.py
torchvision/ops/ciou_loss.py
+9
-2
torchvision/ops/deform_conv.py
torchvision/ops/deform_conv.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
370 of 370+
files are displayed.
Plain diff
Email patch
torchvision/models/quantization/shufflenetv2.py
View file @
cc26cd81
...
@@ -108,7 +108,7 @@ def _shufflenetv2(
...
@@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model
(
model
,
backend
)
quantize_model
(
model
,
backend
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -139,6 +139,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
...
@@ -139,6 +139,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
"acc@5"
:
79.780
,
"acc@5"
:
79.780
,
}
}
},
},
"_ops"
:
0.04
,
"_file_size"
:
1.501
,
},
},
)
)
DEFAULT
=
IMAGENET1K_FBGEMM_V1
DEFAULT
=
IMAGENET1K_FBGEMM_V1
...
@@ -146,7 +148,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
...
@@ -146,7 +148,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
class
ShuffleNet_V2_X1_0_QuantizedWeights
(
WeightsEnum
):
class
ShuffleNet_V2_X1_0_QuantizedWeights
(
WeightsEnum
):
IMAGENET1K_FBGEMM_V1
=
Weights
(
IMAGENET1K_FBGEMM_V1
=
Weights
(
url
=
"https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-
db332c57
.pth"
,
url
=
"https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-
1e62bb32
.pth"
,
transforms
=
partial
(
ImageClassification
,
crop_size
=
224
),
transforms
=
partial
(
ImageClassification
,
crop_size
=
224
),
meta
=
{
meta
=
{
**
_COMMON_META
,
**
_COMMON_META
,
...
@@ -158,6 +160,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
...
@@ -158,6 +160,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
"acc@5"
:
87.582
,
"acc@5"
:
87.582
,
}
}
},
},
"_ops"
:
0.145
,
"_file_size"
:
2.334
,
},
},
)
)
DEFAULT
=
IMAGENET1K_FBGEMM_V1
DEFAULT
=
IMAGENET1K_FBGEMM_V1
...
@@ -178,6 +182,8 @@ class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
...
@@ -178,6 +182,8 @@ class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
"acc@5"
:
90.700
,
"acc@5"
:
90.700
,
}
}
},
},
"_ops"
:
0.296
,
"_file_size"
:
3.672
,
},
},
)
)
DEFAULT
=
IMAGENET1K_FBGEMM_V1
DEFAULT
=
IMAGENET1K_FBGEMM_V1
...
@@ -198,6 +204,8 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
...
@@ -198,6 +204,8 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
"acc@5"
:
92.488
,
"acc@5"
:
92.488
,
}
}
},
},
"_ops"
:
0.583
,
"_file_size"
:
7.467
,
},
},
)
)
DEFAULT
=
IMAGENET1K_FBGEMM_V1
DEFAULT
=
IMAGENET1K_FBGEMM_V1
...
@@ -417,16 +425,3 @@ def shufflenet_v2_x2_0(
...
@@ -417,16 +425,3 @@ def shufflenet_v2_x2_0(
return
_shufflenetv2
(
return
_shufflenetv2
(
[
4
,
8
,
4
],
[
24
,
244
,
488
,
976
,
2048
],
weights
=
weights
,
progress
=
progress
,
quantize
=
quantize
,
**
kwargs
[
4
,
8
,
4
],
[
24
,
244
,
488
,
976
,
2048
],
weights
=
weights
,
progress
=
progress
,
quantize
=
quantize
,
**
kwargs
)
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
.._utils
import
_ModelURLs
from
..shufflenetv2
import
model_urls
# noqa: F401
quant_model_urls
=
_ModelURLs
(
{
"shufflenetv2_x0.5_fbgemm"
:
ShuffleNet_V2_X0_5_QuantizedWeights
.
IMAGENET1K_FBGEMM_V1
.
url
,
"shufflenetv2_x1.0_fbgemm"
:
ShuffleNet_V2_X1_0_QuantizedWeights
.
IMAGENET1K_FBGEMM_V1
.
url
,
}
)
torchvision/models/regnet.py
View file @
cc26cd81
...
@@ -212,7 +212,7 @@ class BlockParams:
...
@@ -212,7 +212,7 @@ class BlockParams:
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
"BlockParams"
:
)
->
"BlockParams"
:
"""
"""
Programatically compute all the per-block settings,
Program
m
atically compute all the per-block settings,
given the RegNet parameters.
given the RegNet parameters.
The first step is to compute the quantized linear block parameters,
The first step is to compute the quantized linear block parameters,
...
@@ -397,7 +397,7 @@ def _regnet(
...
@@ -397,7 +397,7 @@ def _regnet(
model
=
RegNet
(
block_params
,
norm_layer
=
norm_layer
,
**
kwargs
)
model
=
RegNet
(
block_params
,
norm_layer
=
norm_layer
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -428,6 +428,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
...
@@ -428,6 +428,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@5"
:
91.716
,
"acc@5"
:
91.716
,
}
}
},
},
"_ops"
:
0.402
,
"_file_size"
:
16.806
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -444,6 +446,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
...
@@ -444,6 +446,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@5"
:
92.742
,
"acc@5"
:
92.742
,
}
}
},
},
"_ops"
:
0.402
,
"_file_size"
:
16.806
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -468,6 +472,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
...
@@ -468,6 +472,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@5"
:
93.136
,
"acc@5"
:
93.136
,
}
}
},
},
"_ops"
:
0.834
,
"_file_size"
:
24.774
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -484,6 +490,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
...
@@ -484,6 +490,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@5"
:
94.502
,
"acc@5"
:
94.502
,
}
}
},
},
"_ops"
:
0.834
,
"_file_size"
:
24.774
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -508,6 +516,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
...
@@ -508,6 +516,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@5"
:
93.966
,
"acc@5"
:
93.966
,
}
}
},
},
"_ops"
:
1.612
,
"_file_size"
:
43.152
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -524,6 +534,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
...
@@ -524,6 +534,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@5"
:
95.444
,
"acc@5"
:
95.444
,
}
}
},
},
"_ops"
:
1.612
,
"_file_size"
:
43.152
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -548,6 +560,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
...
@@ -548,6 +560,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@5"
:
94.576
,
"acc@5"
:
94.576
,
}
}
},
},
"_ops"
:
3.176
,
"_file_size"
:
74.567
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -564,6 +578,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
...
@@ -564,6 +578,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@5"
:
95.972
,
"acc@5"
:
95.972
,
}
}
},
},
"_ops"
:
3.176
,
"_file_size"
:
74.567
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -588,6 +604,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
...
@@ -588,6 +604,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@5"
:
95.048
,
"acc@5"
:
95.048
,
}
}
},
},
"_ops"
:
8.473
,
"_file_size"
:
150.701
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -604,6 +622,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
...
@@ -604,6 +622,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@5"
:
96.330
,
"acc@5"
:
96.330
,
}
}
},
},
"_ops"
:
8.473
,
"_file_size"
:
150.701
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -628,6 +648,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
...
@@ -628,6 +648,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5"
:
95.240
,
"acc@5"
:
95.240
,
}
}
},
},
"_ops"
:
15.912
,
"_file_size"
:
319.49
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -644,6 +666,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
...
@@ -644,6 +666,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5"
:
96.328
,
"acc@5"
:
96.328
,
}
}
},
},
"_ops"
:
15.912
,
"_file_size"
:
319.49
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -665,6 +689,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
...
@@ -665,6 +689,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5"
:
98.054
,
"acc@5"
:
98.054
,
}
}
},
},
"_ops"
:
46.735
,
"_file_size"
:
319.49
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -686,6 +712,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
...
@@ -686,6 +712,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5"
:
97.244
,
"acc@5"
:
97.244
,
}
}
},
},
"_ops"
:
15.912
,
"_file_size"
:
319.49
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -709,6 +737,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
...
@@ -709,6 +737,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5"
:
95.340
,
"acc@5"
:
95.340
,
}
}
},
},
"_ops"
:
32.28
,
"_file_size"
:
554.076
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -725,6 +755,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
...
@@ -725,6 +755,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5"
:
96.498
,
"acc@5"
:
96.498
,
}
}
},
},
"_ops"
:
32.28
,
"_file_size"
:
554.076
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -746,6 +778,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
...
@@ -746,6 +778,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5"
:
98.362
,
"acc@5"
:
98.362
,
}
}
},
},
"_ops"
:
94.826
,
"_file_size"
:
554.076
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -767,6 +801,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
...
@@ -767,6 +801,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5"
:
97.480
,
"acc@5"
:
97.480
,
}
}
},
},
"_ops"
:
32.28
,
"_file_size"
:
554.076
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -791,6 +827,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
...
@@ -791,6 +827,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5"
:
98.682
,
"acc@5"
:
98.682
,
}
}
},
},
"_ops"
:
374.57
,
"_file_size"
:
2461.564
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -812,6 +850,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
...
@@ -812,6 +850,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5"
:
97.844
,
"acc@5"
:
97.844
,
}
}
},
},
"_ops"
:
127.518
,
"_file_size"
:
2461.564
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -835,6 +875,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
...
@@ -835,6 +875,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@5"
:
90.950
,
"acc@5"
:
90.950
,
}
}
},
},
"_ops"
:
0.414
,
"_file_size"
:
21.258
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -851,6 +893,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
...
@@ -851,6 +893,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@5"
:
92.322
,
"acc@5"
:
92.322
,
}
}
},
},
"_ops"
:
0.414
,
"_file_size"
:
21.257
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -875,6 +919,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
...
@@ -875,6 +919,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@5"
:
92.348
,
"acc@5"
:
92.348
,
}
}
},
},
"_ops"
:
0.8
,
"_file_size"
:
27.945
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -891,6 +937,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
...
@@ -891,6 +937,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@5"
:
93.826
,
"acc@5"
:
93.826
,
}
}
},
},
"_ops"
:
0.8
,
"_file_size"
:
27.945
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -915,6 +963,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
...
@@ -915,6 +963,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@5"
:
93.440
,
"acc@5"
:
93.440
,
}
}
},
},
"_ops"
:
1.603
,
"_file_size"
:
35.339
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -931,6 +981,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
...
@@ -931,6 +981,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@5"
:
94.922
,
"acc@5"
:
94.922
,
}
}
},
},
"_ops"
:
1.603
,
"_file_size"
:
35.339
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -955,6 +1007,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
...
@@ -955,6 +1007,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@5"
:
93.992
,
"acc@5"
:
93.992
,
}
}
},
},
"_ops"
:
3.177
,
"_file_size"
:
58.756
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -971,6 +1025,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
...
@@ -971,6 +1025,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@5"
:
95.430
,
"acc@5"
:
95.430
,
}
}
},
},
"_ops"
:
3.177
,
"_file_size"
:
58.756
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -995,6 +1051,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
...
@@ -995,6 +1051,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@5"
:
94.686
,
"acc@5"
:
94.686
,
}
}
},
},
"_ops"
:
7.995
,
"_file_size"
:
151.456
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -1011,6 +1069,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
...
@@ -1011,6 +1069,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@5"
:
95.678
,
"acc@5"
:
95.678
,
}
}
},
},
"_ops"
:
7.995
,
"_file_size"
:
151.456
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -1035,6 +1095,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
...
@@ -1035,6 +1095,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@5"
:
94.944
,
"acc@5"
:
94.944
,
}
}
},
},
"_ops"
:
15.941
,
"_file_size"
:
207.627
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -1051,6 +1113,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
...
@@ -1051,6 +1113,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@5"
:
96.196
,
"acc@5"
:
96.196
,
}
}
},
},
"_ops"
:
15.941
,
"_file_size"
:
207.627
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -1075,6 +1139,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
...
@@ -1075,6 +1139,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@5"
:
95.248
,
"acc@5"
:
95.248
,
}
}
},
},
"_ops"
:
31.736
,
"_file_size"
:
412.039
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -1091,6 +1157,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
...
@@ -1091,6 +1157,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@5"
:
96.288
,
"acc@5"
:
96.288
,
}
}
},
},
"_ops"
:
31.736
,
"_file_size"
:
412.039
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using a modified version of TorchVision's
These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -1501,27 +1569,3 @@ def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress:
...
@@ -1501,27 +1569,3 @@ def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress:
params
=
BlockParams
.
from_init_params
(
depth
=
23
,
w_0
=
320
,
w_a
=
69.86
,
w_m
=
2.0
,
group_width
=
168
,
**
kwargs
)
params
=
BlockParams
.
from_init_params
(
depth
=
23
,
w_0
=
320
,
w_a
=
69.86
,
w_m
=
2.0
,
group_width
=
168
,
**
kwargs
)
return
_regnet
(
params
,
weights
,
progress
,
**
kwargs
)
return
_regnet
(
params
,
weights
,
progress
,
**
kwargs
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"regnet_y_400mf"
:
RegNet_Y_400MF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_800mf"
:
RegNet_Y_800MF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_1_6gf"
:
RegNet_Y_1_6GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_3_2gf"
:
RegNet_Y_3_2GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_8gf"
:
RegNet_Y_8GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_16gf"
:
RegNet_Y_16GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_y_32gf"
:
RegNet_Y_32GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_400mf"
:
RegNet_X_400MF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_800mf"
:
RegNet_X_800MF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_1_6gf"
:
RegNet_X_1_6GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_3_2gf"
:
RegNet_X_3_2GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_8gf"
:
RegNet_X_8GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_16gf"
:
RegNet_X_16GF_Weights
.
IMAGENET1K_V1
.
url
,
"regnet_x_32gf"
:
RegNet_X_32GF_Weights
.
IMAGENET1K_V1
.
url
,
}
)
torchvision/models/resnet.py
View file @
cc26cd81
...
@@ -108,7 +108,7 @@ class BasicBlock(nn.Module):
...
@@ -108,7 +108,7 @@ class BasicBlock(nn.Module):
class
Bottleneck
(
nn
.
Module
):
class
Bottleneck
(
nn
.
Module
):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# according to "Deep residual learning for image recognition"
https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
...
@@ -298,7 +298,7 @@ def _resnet(
...
@@ -298,7 +298,7 @@ def _resnet(
model
=
ResNet
(
block
,
layers
,
**
kwargs
)
model
=
ResNet
(
block
,
layers
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -323,6 +323,8 @@ class ResNet18_Weights(WeightsEnum):
...
@@ -323,6 +323,8 @@ class ResNet18_Weights(WeightsEnum):
"acc@5"
:
89.078
,
"acc@5"
:
89.078
,
}
}
},
},
"_ops"
:
1.814
,
"_file_size"
:
44.661
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -343,6 +345,8 @@ class ResNet34_Weights(WeightsEnum):
...
@@ -343,6 +345,8 @@ class ResNet34_Weights(WeightsEnum):
"acc@5"
:
91.420
,
"acc@5"
:
91.420
,
}
}
},
},
"_ops"
:
3.664
,
"_file_size"
:
83.275
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -363,6 +367,8 @@ class ResNet50_Weights(WeightsEnum):
...
@@ -363,6 +367,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@5"
:
92.862
,
"acc@5"
:
92.862
,
}
}
},
},
"_ops"
:
4.089
,
"_file_size"
:
97.781
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -379,6 +385,8 @@ class ResNet50_Weights(WeightsEnum):
...
@@ -379,6 +385,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@5"
:
95.434
,
"acc@5"
:
95.434
,
}
}
},
},
"_ops"
:
4.089
,
"_file_size"
:
97.79
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -402,6 +410,8 @@ class ResNet101_Weights(WeightsEnum):
...
@@ -402,6 +410,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@5"
:
93.546
,
"acc@5"
:
93.546
,
}
}
},
},
"_ops"
:
7.801
,
"_file_size"
:
170.511
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -418,6 +428,8 @@ class ResNet101_Weights(WeightsEnum):
...
@@ -418,6 +428,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@5"
:
95.780
,
"acc@5"
:
95.780
,
}
}
},
},
"_ops"
:
7.801
,
"_file_size"
:
170.53
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -441,6 +453,8 @@ class ResNet152_Weights(WeightsEnum):
...
@@ -441,6 +453,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@5"
:
94.046
,
"acc@5"
:
94.046
,
}
}
},
},
"_ops"
:
11.514
,
"_file_size"
:
230.434
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -457,6 +471,8 @@ class ResNet152_Weights(WeightsEnum):
...
@@ -457,6 +471,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@5"
:
96.002
,
"acc@5"
:
96.002
,
}
}
},
},
"_ops"
:
11.514
,
"_file_size"
:
230.474
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -480,6 +496,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
...
@@ -480,6 +496,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@5"
:
93.698
,
"acc@5"
:
93.698
,
}
}
},
},
"_ops"
:
4.23
,
"_file_size"
:
95.789
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -496,6 +514,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
...
@@ -496,6 +514,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@5"
:
95.340
,
"acc@5"
:
95.340
,
}
}
},
},
"_ops"
:
4.23
,
"_file_size"
:
95.833
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -519,6 +539,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
...
@@ -519,6 +539,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@5"
:
94.526
,
"acc@5"
:
94.526
,
}
}
},
},
"_ops"
:
16.414
,
"_file_size"
:
339.586
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -535,6 +557,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
...
@@ -535,6 +557,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@5"
:
96.228
,
"acc@5"
:
96.228
,
}
}
},
},
"_ops"
:
16.414
,
"_file_size"
:
339.673
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -558,6 +582,8 @@ class ResNeXt101_64X4D_Weights(WeightsEnum):
...
@@ -558,6 +582,8 @@ class ResNeXt101_64X4D_Weights(WeightsEnum):
"acc@5"
:
96.454
,
"acc@5"
:
96.454
,
}
}
},
},
"_ops"
:
15.46
,
"_file_size"
:
319.318
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using TorchVision's `new training recipe
These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -581,6 +607,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
...
@@ -581,6 +607,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@5"
:
94.086
,
"acc@5"
:
94.086
,
}
}
},
},
"_ops"
:
11.398
,
"_file_size"
:
131.82
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -597,6 +625,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
...
@@ -597,6 +625,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@5"
:
95.758
,
"acc@5"
:
95.758
,
}
}
},
},
"_ops"
:
11.398
,
"_file_size"
:
263.124
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -620,6 +650,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
...
@@ -620,6 +650,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@5"
:
94.284
,
"acc@5"
:
94.284
,
}
}
},
},
"_ops"
:
22.753
,
"_file_size"
:
242.896
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a simple training recipe."""
,
},
},
)
)
...
@@ -636,6 +668,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
...
@@ -636,6 +668,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@5"
:
96.020
,
"acc@5"
:
96.020
,
}
}
},
},
"_ops"
:
22.753
,
"_file_size"
:
484.747
,
"_docs"
:
"""
"_docs"
:
"""
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -648,7 +682,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
...
@@ -648,7 +682,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet18_Weights
.
IMAGENET1K_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet18_Weights
.
IMAGENET1K_V1
))
def
resnet18
(
*
,
weights
:
Optional
[
ResNet18_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
def
resnet18
(
*
,
weights
:
Optional
[
ResNet18_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
pdf
/1512.03385
.pdf
>`__.
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
abs
/1512.03385>`__.
Args:
Args:
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
...
@@ -674,7 +708,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
...
@@ -674,7 +708,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet34_Weights
.
IMAGENET1K_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet34_Weights
.
IMAGENET1K_V1
))
def
resnet34
(
*
,
weights
:
Optional
[
ResNet34_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
def
resnet34
(
*
,
weights
:
Optional
[
ResNet34_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
pdf
/1512.03385
.pdf
>`__.
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
abs
/1512.03385>`__.
Args:
Args:
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
...
@@ -700,7 +734,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
...
@@ -700,7 +734,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet50_Weights
.
IMAGENET1K_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet50_Weights
.
IMAGENET1K_V1
))
def
resnet50
(
*
,
weights
:
Optional
[
ResNet50_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
def
resnet50
(
*
,
weights
:
Optional
[
ResNet50_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
pdf
/1512.03385
.pdf
>`__.
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
abs
/1512.03385>`__.
.. note::
.. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...
@@ -732,7 +766,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
...
@@ -732,7 +766,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet101_Weights
.
IMAGENET1K_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet101_Weights
.
IMAGENET1K_V1
))
def
resnet101
(
*
,
weights
:
Optional
[
ResNet101_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
def
resnet101
(
*
,
weights
:
Optional
[
ResNet101_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
pdf
/1512.03385
.pdf
>`__.
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
abs
/1512.03385>`__.
.. note::
.. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...
@@ -764,7 +798,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
...
@@ -764,7 +798,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet152_Weights
.
IMAGENET1K_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
ResNet152_Weights
.
IMAGENET1K_V1
))
def
resnet152
(
*
,
weights
:
Optional
[
ResNet152_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
def
resnet152
(
*
,
weights
:
Optional
[
ResNet152_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
ResNet
:
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
pdf
/1512.03385
.pdf
>`__.
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/
abs
/1512.03385>`__.
.. note::
.. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...
@@ -949,22 +983,3 @@ def wide_resnet101_2(
...
@@ -949,22 +983,3 @@ def wide_resnet101_2(
_ovewrite_named_param
(
kwargs
,
"width_per_group"
,
64
*
2
)
_ovewrite_named_param
(
kwargs
,
"width_per_group"
,
64
*
2
)
return
_resnet
(
Bottleneck
,
[
3
,
4
,
23
,
3
],
weights
,
progress
,
**
kwargs
)
return
_resnet
(
Bottleneck
,
[
3
,
4
,
23
,
3
],
weights
,
progress
,
**
kwargs
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"resnet18"
:
ResNet18_Weights
.
IMAGENET1K_V1
.
url
,
"resnet34"
:
ResNet34_Weights
.
IMAGENET1K_V1
.
url
,
"resnet50"
:
ResNet50_Weights
.
IMAGENET1K_V1
.
url
,
"resnet101"
:
ResNet101_Weights
.
IMAGENET1K_V1
.
url
,
"resnet152"
:
ResNet152_Weights
.
IMAGENET1K_V1
.
url
,
"resnext50_32x4d"
:
ResNeXt50_32X4D_Weights
.
IMAGENET1K_V1
.
url
,
"resnext101_32x8d"
:
ResNeXt101_32X8D_Weights
.
IMAGENET1K_V1
.
url
,
"wide_resnet50_2"
:
Wide_ResNet50_2_Weights
.
IMAGENET1K_V1
.
url
,
"wide_resnet101_2"
:
Wide_ResNet101_2_Weights
.
IMAGENET1K_V1
.
url
,
}
)
torchvision/models/segmentation/deeplabv3.py
View file @
cc26cd81
...
@@ -152,6 +152,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
...
@@ -152,6 +152,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
"pixel_acc"
:
92.4
,
"pixel_acc"
:
92.4
,
}
}
},
},
"_ops"
:
178.722
,
"_file_size"
:
160.515
,
},
},
)
)
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
...
@@ -171,6 +173,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
...
@@ -171,6 +173,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
"pixel_acc"
:
92.4
,
"pixel_acc"
:
92.4
,
}
}
},
},
"_ops"
:
258.743
,
"_file_size"
:
233.217
,
},
},
)
)
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
...
@@ -190,6 +194,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
...
@@ -190,6 +194,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
"pixel_acc"
:
91.2
,
"pixel_acc"
:
91.2
,
}
}
},
},
"_ops"
:
10.452
,
"_file_size"
:
42.301
,
},
},
)
)
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
...
@@ -269,7 +275,7 @@ def deeplabv3_resnet50(
...
@@ -269,7 +275,7 @@ def deeplabv3_resnet50(
model
=
_deeplabv3_resnet
(
backbone
,
num_classes
,
aux_loss
)
model
=
_deeplabv3_resnet
(
backbone
,
num_classes
,
aux_loss
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -325,7 +331,7 @@ def deeplabv3_resnet101(
...
@@ -325,7 +331,7 @@ def deeplabv3_resnet101(
model
=
_deeplabv3_resnet
(
backbone
,
num_classes
,
aux_loss
)
model
=
_deeplabv3_resnet
(
backbone
,
num_classes
,
aux_loss
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -379,19 +385,6 @@ def deeplabv3_mobilenet_v3_large(
...
@@ -379,19 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model
=
_deeplabv3_mobilenetv3
(
backbone
,
num_classes
,
aux_loss
)
model
=
_deeplabv3_mobilenetv3
(
backbone
,
num_classes
,
aux_loss
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
# The dictionary below is internal implementation detail and will be removed in v0.15
from
.._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"deeplabv3_resnet50_coco"
:
DeepLabV3_ResNet50_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
"deeplabv3_resnet101_coco"
:
DeepLabV3_ResNet101_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
"deeplabv3_mobilenet_v3_large_coco"
:
DeepLabV3_MobileNet_V3_Large_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
}
)
torchvision/models/segmentation/fcn.py
View file @
cc26cd81
...
@@ -71,6 +71,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
...
@@ -71,6 +71,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
"pixel_acc"
:
91.4
,
"pixel_acc"
:
91.4
,
}
}
},
},
"_ops"
:
152.717
,
"_file_size"
:
135.009
,
},
},
)
)
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
...
@@ -90,6 +92,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
...
@@ -90,6 +92,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
"pixel_acc"
:
91.9
,
"pixel_acc"
:
91.9
,
}
}
},
},
"_ops"
:
232.738
,
"_file_size"
:
207.711
,
},
},
)
)
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
DEFAULT
=
COCO_WITH_VOC_LABELS_V1
...
@@ -164,7 +168,7 @@ def fcn_resnet50(
...
@@ -164,7 +168,7 @@ def fcn_resnet50(
model
=
_fcn_resnet
(
backbone
,
num_classes
,
aux_loss
)
model
=
_fcn_resnet
(
backbone
,
num_classes
,
aux_loss
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -223,18 +227,6 @@ def fcn_resnet101(
...
@@ -223,18 +227,6 @@ def fcn_resnet101(
model
=
_fcn_resnet
(
backbone
,
num_classes
,
aux_loss
)
model
=
_fcn_resnet
(
backbone
,
num_classes
,
aux_loss
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
# The dictionary below is internal implementation detail and will be removed in v0.15
from
.._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"fcn_resnet50_coco"
:
FCN_ResNet50_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
"fcn_resnet101_coco"
:
FCN_ResNet101_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
}
)
torchvision/models/segmentation/lraspp.py
View file @
cc26cd81
...
@@ -108,6 +108,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
...
@@ -108,6 +108,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
"pixel_acc"
:
91.2
,
"pixel_acc"
:
91.2
,
}
}
},
},
"_ops"
:
2.086
,
"_file_size"
:
12.49
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained on a subset of COCO, using only the 20 categories that are present in the
These weights were trained on a subset of COCO, using only the 20 categories that are present in the
Pascal VOC dataset.
Pascal VOC dataset.
...
@@ -171,17 +173,6 @@ def lraspp_mobilenet_v3_large(
...
@@ -171,17 +173,6 @@ def lraspp_mobilenet_v3_large(
model
=
_lraspp_mobilenetv3
(
backbone
,
num_classes
)
model
=
_lraspp_mobilenetv3
(
backbone
,
num_classes
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
# The dictionary below is internal implementation detail and will be removed in v0.15
from
.._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"lraspp_mobilenet_v3_large_coco"
:
LRASPP_MobileNet_V3_Large_Weights
.
COCO_WITH_VOC_LABELS_V1
.
url
,
}
)
torchvision/models/shufflenetv2.py
View file @
cc26cd81
...
@@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
...
@@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
x
=
torch
.
transpose
(
x
,
1
,
2
).
contiguous
()
x
=
torch
.
transpose
(
x
,
1
,
2
).
contiguous
()
# flatten
# flatten
x
=
x
.
view
(
batchsize
,
-
1
,
height
,
width
)
x
=
x
.
view
(
batchsize
,
num_channels
,
height
,
width
)
return
x
return
x
...
@@ -178,7 +178,7 @@ def _shufflenetv2(
...
@@ -178,7 +178,7 @@ def _shufflenetv2(
model
=
ShuffleNetV2
(
*
args
,
**
kwargs
)
model
=
ShuffleNetV2
(
*
args
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -204,6 +204,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
...
@@ -204,6 +204,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
"acc@5"
:
81.746
,
"acc@5"
:
81.746
,
}
}
},
},
"_ops"
:
0.04
,
"_file_size"
:
5.282
,
"_docs"
:
"""These weights were trained from scratch to reproduce closely the results of the paper."""
,
"_docs"
:
"""These weights were trained from scratch to reproduce closely the results of the paper."""
,
},
},
)
)
...
@@ -224,6 +226,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
...
@@ -224,6 +226,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
"acc@5"
:
88.316
,
"acc@5"
:
88.316
,
}
}
},
},
"_ops"
:
0.145
,
"_file_size"
:
8.791
,
"_docs"
:
"""These weights were trained from scratch to reproduce closely the results of the paper."""
,
"_docs"
:
"""These weights were trained from scratch to reproduce closely the results of the paper."""
,
},
},
)
)
...
@@ -244,6 +248,8 @@ class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
...
@@ -244,6 +248,8 @@ class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
"acc@5"
:
91.086
,
"acc@5"
:
91.086
,
}
}
},
},
"_ops"
:
0.296
,
"_file_size"
:
13.557
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using TorchVision's `new training recipe
These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -267,6 +273,8 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
...
@@ -267,6 +273,8 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
"acc@5"
:
93.006
,
"acc@5"
:
93.006
,
}
}
},
},
"_ops"
:
0.583
,
"_file_size"
:
28.433
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using TorchVision's `new training recipe
These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...
@@ -398,17 +406,3 @@ def shufflenet_v2_x2_0(
...
@@ -398,17 +406,3 @@ def shufflenet_v2_x2_0(
weights
=
ShuffleNet_V2_X2_0_Weights
.
verify
(
weights
)
weights
=
ShuffleNet_V2_X2_0_Weights
.
verify
(
weights
)
return
_shufflenetv2
(
weights
,
progress
,
[
4
,
8
,
4
],
[
24
,
244
,
488
,
976
,
2048
],
**
kwargs
)
return
_shufflenetv2
(
weights
,
progress
,
[
4
,
8
,
4
],
[
24
,
244
,
488
,
976
,
2048
],
**
kwargs
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"shufflenetv2_x0.5"
:
ShuffleNet_V2_X0_5_Weights
.
IMAGENET1K_V1
.
url
,
"shufflenetv2_x1.0"
:
ShuffleNet_V2_X1_0_Weights
.
IMAGENET1K_V1
.
url
,
"shufflenetv2_x1.5"
:
None
,
"shufflenetv2_x2.0"
:
None
,
}
)
torchvision/models/squeezenet.py
View file @
cc26cd81
...
@@ -109,7 +109,7 @@ def _squeezenet(
...
@@ -109,7 +109,7 @@ def _squeezenet(
model
=
SqueezeNet
(
version
,
**
kwargs
)
model
=
SqueezeNet
(
version
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -135,6 +135,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
...
@@ -135,6 +135,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
"acc@5"
:
80.420
,
"acc@5"
:
80.420
,
}
}
},
},
"_ops"
:
0.819
,
"_file_size"
:
4.778
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -154,6 +156,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
...
@@ -154,6 +156,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
"acc@5"
:
80.624
,
"acc@5"
:
80.624
,
}
}
},
},
"_ops"
:
0.349
,
"_file_size"
:
4.729
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -217,15 +221,3 @@ def squeezenet1_1(
...
@@ -217,15 +221,3 @@ def squeezenet1_1(
"""
"""
weights
=
SqueezeNet1_1_Weights
.
verify
(
weights
)
weights
=
SqueezeNet1_1_Weights
.
verify
(
weights
)
return
_squeezenet
(
"1_1"
,
weights
,
progress
,
**
kwargs
)
return
_squeezenet
(
"1_1"
,
weights
,
progress
,
**
kwargs
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"squeezenet1_0"
:
SqueezeNet1_0_Weights
.
IMAGENET1K_V1
.
url
,
"squeezenet1_1"
:
SqueezeNet1_1_Weights
.
IMAGENET1K_V1
.
url
,
}
)
torchvision/models/swin_transformer.py
View file @
cc26cd81
...
@@ -126,7 +126,8 @@ def shifted_window_attention(
...
@@ -126,7 +126,8 @@ def shifted_window_attention(
qkv_bias
:
Optional
[
Tensor
]
=
None
,
qkv_bias
:
Optional
[
Tensor
]
=
None
,
proj_bias
:
Optional
[
Tensor
]
=
None
,
proj_bias
:
Optional
[
Tensor
]
=
None
,
logit_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
logit_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
):
training
:
bool
=
True
,
)
->
Tensor
:
"""
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
It supports both of shifted and non-shifted window.
...
@@ -143,6 +144,7 @@ def shifted_window_attention(
...
@@ -143,6 +144,7 @@ def shifted_window_attention(
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
training (bool, optional): Training flag used by the dropout parameters. Default: True.
Returns:
Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention.
Tensor[N, H, W, C]: The output tensor after shifted window attention.
"""
"""
...
@@ -207,11 +209,11 @@ def shifted_window_attention(
...
@@ -207,11 +209,11 @@ def shifted_window_attention(
attn
=
attn
.
view
(
-
1
,
num_heads
,
x
.
size
(
1
),
x
.
size
(
1
))
attn
=
attn
.
view
(
-
1
,
num_heads
,
x
.
size
(
1
),
x
.
size
(
1
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
dropout
(
attn
,
p
=
attention_dropout
)
attn
=
F
.
dropout
(
attn
,
p
=
attention_dropout
,
training
=
training
)
x
=
attn
.
matmul
(
v
).
transpose
(
1
,
2
).
reshape
(
x
.
size
(
0
),
x
.
size
(
1
),
C
)
x
=
attn
.
matmul
(
v
).
transpose
(
1
,
2
).
reshape
(
x
.
size
(
0
),
x
.
size
(
1
),
C
)
x
=
F
.
linear
(
x
,
proj_weight
,
proj_bias
)
x
=
F
.
linear
(
x
,
proj_weight
,
proj_bias
)
x
=
F
.
dropout
(
x
,
p
=
dropout
)
x
=
F
.
dropout
(
x
,
p
=
dropout
,
training
=
training
)
# reverse windows
# reverse windows
x
=
x
.
view
(
B
,
pad_H
//
window_size
[
0
],
pad_W
//
window_size
[
1
],
window_size
[
0
],
window_size
[
1
],
C
)
x
=
x
.
view
(
B
,
pad_H
//
window_size
[
0
],
pad_W
//
window_size
[
1
],
window_size
[
0
],
window_size
[
1
],
C
)
...
@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
...
@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
self
.
relative_position_bias_table
,
self
.
relative_position_index
,
self
.
window_size
# type: ignore[arg-type]
self
.
relative_position_bias_table
,
self
.
relative_position_index
,
self
.
window_size
# type: ignore[arg-type]
)
)
def
forward
(
self
,
x
:
Tensor
):
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
"""
"""
Args:
Args:
x (Tensor): Tensor with layout of [B, H, W, C]
x (Tensor): Tensor with layout of [B, H, W, C]
...
@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
...
@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
dropout
=
self
.
dropout
,
dropout
=
self
.
dropout
,
qkv_bias
=
self
.
qkv
.
bias
,
qkv_bias
=
self
.
qkv
.
bias
,
proj_bias
=
self
.
proj
.
bias
,
proj_bias
=
self
.
proj
.
bias
,
training
=
self
.
training
,
)
)
...
@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
...
@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
qkv_bias
=
self
.
qkv
.
bias
,
qkv_bias
=
self
.
qkv
.
bias
,
proj_bias
=
self
.
proj
.
bias
,
proj_bias
=
self
.
proj
.
bias
,
logit_scale
=
self
.
logit_scale
,
logit_scale
=
self
.
logit_scale
,
training
=
self
.
training
,
)
)
...
@@ -494,6 +498,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
...
@@ -494,6 +498,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
)
)
def
forward
(
self
,
x
:
Tensor
):
def
forward
(
self
,
x
:
Tensor
):
# Here is the difference, we apply norm after the attention in V2.
# In V1 we applied norm before the attention.
x
=
x
+
self
.
stochastic_depth
(
self
.
norm1
(
self
.
attn
(
x
)))
x
=
x
+
self
.
stochastic_depth
(
self
.
norm1
(
self
.
attn
(
x
)))
x
=
x
+
self
.
stochastic_depth
(
self
.
norm2
(
self
.
mlp
(
x
)))
x
=
x
+
self
.
stochastic_depth
(
self
.
norm2
(
self
.
mlp
(
x
)))
return
x
return
x
...
@@ -502,7 +508,7 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
...
@@ -502,7 +508,7 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
class
SwinTransformer
(
nn
.
Module
):
class
SwinTransformer
(
nn
.
Module
):
"""
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/
pdf
/2103.14030>`_ paper.
Shifted Windows" <https://arxiv.org/
abs
/2103.14030>`_ paper.
Args:
Args:
patch_size (List[int]): Patch size.
patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension.
embed_dim (int): Patch embedding dimension.
...
@@ -587,7 +593,7 @@ class SwinTransformer(nn.Module):
...
@@ -587,7 +593,7 @@ class SwinTransformer(nn.Module):
num_features
=
embed_dim
*
2
**
(
len
(
depths
)
-
1
)
num_features
=
embed_dim
*
2
**
(
len
(
depths
)
-
1
)
self
.
norm
=
norm_layer
(
num_features
)
self
.
norm
=
norm_layer
(
num_features
)
self
.
permute
=
Permute
([
0
,
3
,
1
,
2
])
self
.
permute
=
Permute
([
0
,
3
,
1
,
2
])
# B H W C -> B C H W
self
.
avgpool
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
flatten
=
nn
.
Flatten
(
1
)
self
.
flatten
=
nn
.
Flatten
(
1
)
self
.
head
=
nn
.
Linear
(
num_features
,
num_classes
)
self
.
head
=
nn
.
Linear
(
num_features
,
num_classes
)
...
@@ -633,7 +639,7 @@ def _swin_transformer(
...
@@ -633,7 +639,7 @@ def _swin_transformer(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -660,6 +666,8 @@ class Swin_T_Weights(WeightsEnum):
...
@@ -660,6 +666,8 @@ class Swin_T_Weights(WeightsEnum):
"acc@5"
:
95.776
,
"acc@5"
:
95.776
,
}
}
},
},
"_ops"
:
4.491
,
"_file_size"
:
108.19
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -683,6 +691,8 @@ class Swin_S_Weights(WeightsEnum):
...
@@ -683,6 +691,8 @@ class Swin_S_Weights(WeightsEnum):
"acc@5"
:
96.360
,
"acc@5"
:
96.360
,
}
}
},
},
"_ops"
:
8.741
,
"_file_size"
:
189.786
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -706,6 +716,8 @@ class Swin_B_Weights(WeightsEnum):
...
@@ -706,6 +716,8 @@ class Swin_B_Weights(WeightsEnum):
"acc@5"
:
96.640
,
"acc@5"
:
96.640
,
}
}
},
},
"_ops"
:
15.431
,
"_file_size"
:
335.364
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -729,6 +741,8 @@ class Swin_V2_T_Weights(WeightsEnum):
...
@@ -729,6 +741,8 @@ class Swin_V2_T_Weights(WeightsEnum):
"acc@5"
:
96.132
,
"acc@5"
:
96.132
,
}
}
},
},
"_ops"
:
5.94
,
"_file_size"
:
108.626
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -752,6 +766,8 @@ class Swin_V2_S_Weights(WeightsEnum):
...
@@ -752,6 +766,8 @@ class Swin_V2_S_Weights(WeightsEnum):
"acc@5"
:
96.816
,
"acc@5"
:
96.816
,
}
}
},
},
"_ops"
:
11.546
,
"_file_size"
:
190.675
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -775,6 +791,8 @@ class Swin_V2_B_Weights(WeightsEnum):
...
@@ -775,6 +791,8 @@ class Swin_V2_B_Weights(WeightsEnum):
"acc@5"
:
96.864
,
"acc@5"
:
96.864
,
}
}
},
},
"_ops"
:
20.325
,
"_file_size"
:
336.372
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
"_docs"
:
"""These weights reproduce closely the results of the paper using a similar training recipe."""
,
},
},
)
)
...
@@ -786,7 +804,7 @@ class Swin_V2_B_Weights(WeightsEnum):
...
@@ -786,7 +804,7 @@ class Swin_V2_B_Weights(WeightsEnum):
def
swin_t
(
*
,
weights
:
Optional
[
Swin_T_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_t
(
*
,
weights
:
Optional
[
Swin_T_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_tiny architecture from
Constructs a swin_tiny architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
pdf
/2103.14030>`_.
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
abs
/2103.14030>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
...
@@ -824,7 +842,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
...
@@ -824,7 +842,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
def
swin_s
(
*
,
weights
:
Optional
[
Swin_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_s
(
*
,
weights
:
Optional
[
Swin_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_small architecture from
Constructs a swin_small architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
pdf
/2103.14030>`_.
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
abs
/2103.14030>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
...
@@ -862,7 +880,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
...
@@ -862,7 +880,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
def
swin_b
(
*
,
weights
:
Optional
[
Swin_B_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_b
(
*
,
weights
:
Optional
[
Swin_B_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_base architecture from
Constructs a swin_base architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
pdf
/2103.14030>`_.
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/
abs
/2103.14030>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
...
@@ -900,7 +918,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
...
@@ -900,7 +918,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
def
swin_v2_t
(
*
,
weights
:
Optional
[
Swin_V2_T_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_v2_t
(
*
,
weights
:
Optional
[
Swin_V2_T_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_v2_tiny architecture from
Constructs a swin_v2_tiny architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
pdf
/2111.09883>`_.
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
abs
/2111.09883>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
...
@@ -940,7 +958,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T
...
@@ -940,7 +958,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T
def
swin_v2_s
(
*
,
weights
:
Optional
[
Swin_V2_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_v2_s
(
*
,
weights
:
Optional
[
Swin_V2_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_v2_small architecture from
Constructs a swin_v2_small architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
pdf
/2111.09883>`_.
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
abs
/2111.09883>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
...
@@ -980,7 +998,7 @@ def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = T
...
@@ -980,7 +998,7 @@ def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = T
def
swin_v2_b
(
*
,
weights
:
Optional
[
Swin_V2_B_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
def
swin_v2_b
(
*
,
weights
:
Optional
[
Swin_V2_B_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
SwinTransformer
:
"""
"""
Constructs a swin_v2_base architecture from
Constructs a swin_v2_base architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
pdf
/2111.09883>`_.
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/
abs
/2111.09883>`_.
Args:
Args:
weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
...
...
torchvision/models/vgg.py
View file @
cc26cd81
...
@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
...
@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param
(
kwargs
,
"num_classes"
,
len
(
weights
.
meta
[
"categories"
]))
_ovewrite_named_param
(
kwargs
,
"num_classes"
,
len
(
weights
.
meta
[
"categories"
]))
model
=
VGG
(
make_layers
(
cfgs
[
cfg
],
batch_norm
=
batch_norm
),
**
kwargs
)
model
=
VGG
(
make_layers
(
cfgs
[
cfg
],
batch_norm
=
batch_norm
),
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -127,6 +127,8 @@ class VGG11_Weights(WeightsEnum):
...
@@ -127,6 +127,8 @@ class VGG11_Weights(WeightsEnum):
"acc@5"
:
88.628
,
"acc@5"
:
88.628
,
}
}
},
},
"_ops"
:
7.609
,
"_file_size"
:
506.84
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum):
...
@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum):
"acc@5"
:
89.810
,
"acc@5"
:
89.810
,
}
}
},
},
"_ops"
:
7.609
,
"_file_size"
:
506.881
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -163,6 +167,8 @@ class VGG13_Weights(WeightsEnum):
...
@@ -163,6 +167,8 @@ class VGG13_Weights(WeightsEnum):
"acc@5"
:
89.246
,
"acc@5"
:
89.246
,
}
}
},
},
"_ops"
:
11.308
,
"_file_size"
:
507.545
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -181,6 +187,8 @@ class VGG13_BN_Weights(WeightsEnum):
...
@@ -181,6 +187,8 @@ class VGG13_BN_Weights(WeightsEnum):
"acc@5"
:
90.374
,
"acc@5"
:
90.374
,
}
}
},
},
"_ops"
:
11.308
,
"_file_size"
:
507.59
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -199,6 +207,8 @@ class VGG16_Weights(WeightsEnum):
...
@@ -199,6 +207,8 @@ class VGG16_Weights(WeightsEnum):
"acc@5"
:
90.382
,
"acc@5"
:
90.382
,
}
}
},
},
"_ops"
:
15.47
,
"_file_size"
:
527.796
,
},
},
)
)
IMAGENET1K_FEATURES
=
Weights
(
IMAGENET1K_FEATURES
=
Weights
(
...
@@ -221,6 +231,8 @@ class VGG16_Weights(WeightsEnum):
...
@@ -221,6 +231,8 @@ class VGG16_Weights(WeightsEnum):
"acc@5"
:
float
(
"nan"
),
"acc@5"
:
float
(
"nan"
),
}
}
},
},
"_ops"
:
15.47
,
"_file_size"
:
527.802
,
"_docs"
:
"""
"_docs"
:
"""
These weights can't be used for classification because they are missing values in the `classifier`
These weights can't be used for classification because they are missing values in the `classifier`
module. Only the `features` module has valid values and can be used for feature extraction. The weights
module. Only the `features` module has valid values and can be used for feature extraction. The weights
...
@@ -244,6 +256,8 @@ class VGG16_BN_Weights(WeightsEnum):
...
@@ -244,6 +256,8 @@ class VGG16_BN_Weights(WeightsEnum):
"acc@5"
:
91.516
,
"acc@5"
:
91.516
,
}
}
},
},
"_ops"
:
15.47
,
"_file_size"
:
527.866
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -262,6 +276,8 @@ class VGG19_Weights(WeightsEnum):
...
@@ -262,6 +276,8 @@ class VGG19_Weights(WeightsEnum):
"acc@5"
:
90.876
,
"acc@5"
:
90.876
,
}
}
},
},
"_ops"
:
19.632
,
"_file_size"
:
548.051
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -280,6 +296,8 @@ class VGG19_BN_Weights(WeightsEnum):
...
@@ -280,6 +296,8 @@ class VGG19_BN_Weights(WeightsEnum):
"acc@5"
:
91.842
,
"acc@5"
:
91.842
,
}
}
},
},
"_ops"
:
19.632
,
"_file_size"
:
548.143
,
},
},
)
)
DEFAULT
=
IMAGENET1K_V1
DEFAULT
=
IMAGENET1K_V1
...
@@ -491,21 +509,3 @@ def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = Tru
...
@@ -491,21 +509,3 @@ def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = Tru
weights
=
VGG19_BN_Weights
.
verify
(
weights
)
weights
=
VGG19_BN_Weights
.
verify
(
weights
)
return
_vgg
(
"E"
,
True
,
weights
,
progress
,
**
kwargs
)
return
_vgg
(
"E"
,
True
,
weights
,
progress
,
**
kwargs
)
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"vgg11"
:
VGG11_Weights
.
IMAGENET1K_V1
.
url
,
"vgg13"
:
VGG13_Weights
.
IMAGENET1K_V1
.
url
,
"vgg16"
:
VGG16_Weights
.
IMAGENET1K_V1
.
url
,
"vgg19"
:
VGG19_Weights
.
IMAGENET1K_V1
.
url
,
"vgg11_bn"
:
VGG11_BN_Weights
.
IMAGENET1K_V1
.
url
,
"vgg13_bn"
:
VGG13_BN_Weights
.
IMAGENET1K_V1
.
url
,
"vgg16_bn"
:
VGG16_BN_Weights
.
IMAGENET1K_V1
.
url
,
"vgg19_bn"
:
VGG19_BN_Weights
.
IMAGENET1K_V1
.
url
,
}
)
torchvision/models/video/__init__.py
View file @
cc26cd81
from
.mvit
import
*
from
.mvit
import
*
from
.resnet
import
*
from
.resnet
import
*
from
.s3d
import
*
from
.s3d
import
*
from
.swin_transformer
import
*
torchvision/models/video/mvit.py
View file @
cc26cd81
...
@@ -593,7 +593,7 @@ def _mvit(
...
@@ -593,7 +593,7 @@ def _mvit(
)
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -624,6 +624,8 @@ class MViT_V1_B_Weights(WeightsEnum):
...
@@ -624,6 +624,8 @@ class MViT_V1_B_Weights(WeightsEnum):
"acc@5"
:
93.582
,
"acc@5"
:
93.582
,
}
}
},
},
"_ops"
:
70.599
,
"_file_size"
:
139.764
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
@@ -655,6 +657,8 @@ class MViT_V2_S_Weights(WeightsEnum):
...
@@ -655,6 +657,8 @@ class MViT_V2_S_Weights(WeightsEnum):
"acc@5"
:
94.665
,
"acc@5"
:
94.665
,
}
}
},
},
"_ops"
:
64.224
,
"_file_size"
:
131.884
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
@@ -761,9 +765,10 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
...
@@ -761,9 +765,10 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
@
register_model
()
@
register_model
()
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
MViT_V2_S_Weights
.
KINETICS400_V1
))
@
handle_legacy_interface
(
weights
=
(
"pretrained"
,
MViT_V2_S_Weights
.
KINETICS400_V1
))
def
mvit_v2_s
(
*
,
weights
:
Optional
[
MViT_V2_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MViT
:
def
mvit_v2_s
(
*
,
weights
:
Optional
[
MViT_V2_S_Weights
]
=
None
,
progress
:
bool
=
True
,
**
kwargs
:
Any
)
->
MViT
:
"""
"""Constructs a small MViTV2 architecture from
Constructs a small MViTV2 architecture from
`Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
`Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.
`MViTv2: Improved Multiscale Vision Transformers for Classification
and Detection <https://arxiv.org/abs/2112.01526>`__.
.. betastatus:: video module
.. betastatus:: video module
...
@@ -781,7 +786,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T
...
@@ -781,7 +786,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T
for more details about this class.
for more details about this class.
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights
:members:
:members:
"""
"""
weights
=
MViT_V2_S_Weights
.
verify
(
weights
)
weights
=
MViT_V2_S_Weights
.
verify
(
weights
)
...
...
torchvision/models/video/resnet.py
View file @
cc26cd81
...
@@ -303,7 +303,7 @@ def _video_resnet(
...
@@ -303,7 +303,7 @@ def _video_resnet(
model
=
VideoResNet
(
block
,
conv_makers
,
layers
,
stem
,
**
kwargs
)
model
=
VideoResNet
(
block
,
conv_makers
,
layers
,
stem
,
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -332,6 +332,8 @@ class R3D_18_Weights(WeightsEnum):
...
@@ -332,6 +332,8 @@ class R3D_18_Weights(WeightsEnum):
"acc@5"
:
83.479
,
"acc@5"
:
83.479
,
}
}
},
},
"_ops"
:
40.697
,
"_file_size"
:
127.359
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
@@ -350,6 +352,8 @@ class MC3_18_Weights(WeightsEnum):
...
@@ -350,6 +352,8 @@ class MC3_18_Weights(WeightsEnum):
"acc@5"
:
84.130
,
"acc@5"
:
84.130
,
}
}
},
},
"_ops"
:
43.343
,
"_file_size"
:
44.672
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
@@ -368,6 +372,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
...
@@ -368,6 +372,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
"acc@5"
:
86.175
,
"acc@5"
:
86.175
,
}
}
},
},
"_ops"
:
40.519
,
"_file_size"
:
120.318
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
...
torchvision/models/video/s3d.py
View file @
cc26cd81
...
@@ -175,6 +175,8 @@ class S3D_Weights(WeightsEnum):
...
@@ -175,6 +175,8 @@ class S3D_Weights(WeightsEnum):
"acc@5"
:
88.050
,
"acc@5"
:
88.050
,
}
}
},
},
"_ops"
:
17.979
,
"_file_size"
:
31.972
,
},
},
)
)
DEFAULT
=
KINETICS400_V1
DEFAULT
=
KINETICS400_V1
...
@@ -212,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
...
@@ -212,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model
=
S3D
(
**
kwargs
)
model
=
S3D
(
**
kwargs
)
if
weights
is
not
None
:
if
weights
is
not
None
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
torchvision/models/video/swin_transformer.py
0 → 100644
View file @
cc26cd81
# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py
from
functools
import
partial
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
,
Tensor
from
...transforms._presets
import
VideoClassification
from
...utils
import
_log_api_usage_once
from
.._api
import
register_model
,
Weights
,
WeightsEnum
from
.._meta
import
_KINETICS400_CATEGORIES
from
.._utils
import
_ovewrite_named_param
,
handle_legacy_interface
from
..swin_transformer
import
PatchMerging
,
SwinTransformerBlock
__all__
=
[
"SwinTransformer3d"
,
"Swin3D_T_Weights"
,
"Swin3D_S_Weights"
,
"Swin3D_B_Weights"
,
"swin3d_t"
,
"swin3d_s"
,
"swin3d_b"
,
]
def
_get_window_and_shift_size
(
shift_size
:
List
[
int
],
size_dhw
:
List
[
int
],
window_size
:
List
[
int
]
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
for
i
in
range
(
3
):
if
size_dhw
[
i
]
<=
window_size
[
i
]:
# In this case, window_size will adapt to the input size, and no need to shift
window_size
[
i
]
=
size_dhw
[
i
]
shift_size
[
i
]
=
0
return
window_size
,
shift_size
torch
.
fx
.
wrap
(
"_get_window_and_shift_size"
)
def
_get_relative_position_bias
(
relative_position_bias_table
:
torch
.
Tensor
,
relative_position_index
:
torch
.
Tensor
,
window_size
:
List
[
int
]
)
->
Tensor
:
window_vol
=
window_size
[
0
]
*
window_size
[
1
]
*
window_size
[
2
]
# In 3d case we flatten the relative_position_bias
relative_position_bias
=
relative_position_bias_table
[
relative_position_index
[:
window_vol
,
:
window_vol
].
flatten
()
# type: ignore[index]
]
relative_position_bias
=
relative_position_bias
.
view
(
window_vol
,
window_vol
,
-
1
)
relative_position_bias
=
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
().
unsqueeze
(
0
)
return
relative_position_bias
torch
.
fx
.
wrap
(
"_get_relative_position_bias"
)
def
_compute_pad_size_3d
(
size_dhw
:
Tuple
[
int
,
int
,
int
],
patch_size
:
Tuple
[
int
,
int
,
int
])
->
Tuple
[
int
,
int
,
int
]:
pad_size
=
[(
patch_size
[
i
]
-
size_dhw
[
i
]
%
patch_size
[
i
])
%
patch_size
[
i
]
for
i
in
range
(
3
)]
return
pad_size
[
0
],
pad_size
[
1
],
pad_size
[
2
]
torch
.
fx
.
wrap
(
"_compute_pad_size_3d"
)
def
_compute_attention_mask_3d
(
x
:
Tensor
,
size_dhw
:
Tuple
[
int
,
int
,
int
],
window_size
:
Tuple
[
int
,
int
,
int
],
shift_size
:
Tuple
[
int
,
int
,
int
],
)
->
Tensor
:
# generate attention mask
attn_mask
=
x
.
new_zeros
(
*
size_dhw
)
num_windows
=
(
size_dhw
[
0
]
//
window_size
[
0
])
*
(
size_dhw
[
1
]
//
window_size
[
1
])
*
(
size_dhw
[
2
]
//
window_size
[
2
])
slices
=
[
(
(
0
,
-
window_size
[
i
]),
(
-
window_size
[
i
],
-
shift_size
[
i
]),
(
-
shift_size
[
i
],
None
),
)
for
i
in
range
(
3
)
]
count
=
0
for
d
in
slices
[
0
]:
for
h
in
slices
[
1
]:
for
w
in
slices
[
2
]:
attn_mask
[
d
[
0
]
:
d
[
1
],
h
[
0
]
:
h
[
1
],
w
[
0
]
:
w
[
1
]]
=
count
count
+=
1
# Partition window on attn_mask
attn_mask
=
attn_mask
.
view
(
size_dhw
[
0
]
//
window_size
[
0
],
window_size
[
0
],
size_dhw
[
1
]
//
window_size
[
1
],
window_size
[
1
],
size_dhw
[
2
]
//
window_size
[
2
],
window_size
[
2
],
)
attn_mask
=
attn_mask
.
permute
(
0
,
2
,
4
,
1
,
3
,
5
).
reshape
(
num_windows
,
window_size
[
0
]
*
window_size
[
1
]
*
window_size
[
2
]
)
attn_mask
=
attn_mask
.
unsqueeze
(
1
)
-
attn_mask
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
torch
.
fx
.
wrap
(
"_compute_attention_mask_3d"
)
def shifted_window_attention_3d(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): 3-dimensions window size, T, H, W .
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention (T, H, W).
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
    """
    b, t, h, w, c = input.shape
    # Pad feature maps (right side of each T/H/W axis) to multiples of window size.
    pad_size = _compute_pad_size_3d((t, h, w), (window_size[0], window_size[1], window_size[2]))
    x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))
    _, tp, hp, wp, _ = x.shape
    padded_size = (tp, hp, wp)

    # Cyclic shift: roll the volume so windows straddle former window boundaries.
    # Only performed when a non-zero shift was requested on some axis.
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))

    # Partition into non-overlapping (Wd, Wh, Ww) windows; each window's voxels
    # become one attention sequence.
    num_windows = (
        (padded_size[0] // window_size[0]) * (padded_size[1] // window_size[1]) * (padded_size[2] // window_size[2])
    )
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        window_size[0],
        padded_size[1] // window_size[1],
        window_size[1],
        padded_size[2] // window_size[2],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
        b * num_windows, window_size[0] * window_size[1] * window_size[2], c
    )  # B*nW, Wd*Wh*Ww, C

    # Multi-head attention: one fused linear produces q, k and v together.
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Scaled dot-product attention (scale by 1/sqrt(head_dim)).
    q = q * (c // num_heads) ** -0.5
    attn = q.matmul(k.transpose(-2, -1))
    # Add the learned relative position bias to the attention logits.
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # Generate an attention mask so that, after the cyclic shift, voxels
        # originating from different regions cannot attend to each other.
        attn_mask = _compute_attention_mask_3d(
            x,
            (padded_size[0], padded_size[1], padded_size[2]),
            (window_size[0], window_size[1], window_size[2]),
            (shift_size[0], shift_size[1], shift_size[2]),
        )
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))
    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    # Weighted sum over values, then merge heads and project the output.
    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # Reverse the window partition back to a (B, Tp, Hp, Wp, C) volume.
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        padded_size[1] // window_size[1],
        padded_size[2] // window_size[2],
        window_size[0],
        window_size[1],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, tp, hp, wp, c)

    # Reverse the cyclic shift applied above.
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))

    # Strip the padding added at the top of the function.
    x = x[:, :t, :h, :w, :].contiguous()
    return x


torch.fx.wrap("shifted_window_attention_3d")
class ShiftedWindowAttention3d(nn.Module):
    """
    Multi-head self attention over 3d (optionally shifted) windows with a
    learned relative position bias.

    See :func:`shifted_window_attention_3d`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ) -> None:
        """
        Args:
            dim (int): Channel dimension of the input tokens.
            window_size (List[int]): Window size (Wd, Wh, Ww); must have length 3.
            shift_size (List[int]): Shift size (Sd, Sh, Sw); must have length 3.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): Whether the fused qkv projection has a bias. Default: True.
            proj_bias (bool): Whether the output projection has a bias. Default: True.
            attention_dropout (float): Dropout ratio of attention weights. Default: 0.0.
            dropout (float): Dropout ratio of the output. Default: 0.0.

        Raises:
            ValueError: If ``window_size`` or ``shift_size`` does not have exactly 3 entries.
        """
        super().__init__()
        if len(window_size) != 3 or len(shift_size) != 3:
            # Fixed: the check requires three (D, H, W) entries, but the message
            # previously claimed "length 2".
            raise ValueError("window_size and shift_size must be of length 3")
        self.window_size = window_size  # Wd, Wh, Ww
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        # Fused query/key/value projection and the output projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self) -> None:
        # Define a parameter table holding one bias per head for every possible
        # relative offset inside the window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                self.num_heads,
            )
        )  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self) -> None:
        # Get the pair-wise relative position index for each token inside the window.
        coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)]
        coords = torch.stack(
            torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij")
        )  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        # Mixed-radix encoding so the summed coords form a unique flat index.
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        # We don't flatten the relative_position_index here in 3d case.
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor:
        # window_size may be clamped at runtime, hence it is passed explicitly.
        return _get_relative_position_bias(self.relative_position_bias_table, self.relative_position_index, window_size)  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        """Apply (shifted-)window attention to a (B, T, H, W, C) tensor."""
        _, t, h, w, _ = x.shape
        size_dhw = [t, h, w]
        window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
        # Handle the case where window_size is larger than the input tensor.
        window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size)

        relative_position_bias = self.get_relative_position_bias(window_size)

        return shifted_window_attention_3d(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            window_size,
            self.num_heads,
            shift_size=shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )
# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (List[int]): Patch token size.
        in_channels (int): Number of input channels. Default: 3
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size: List[int],
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2])

        # A strided Conv3d with kernel == stride == patch size slices the clip
        # into non-overlapping patches and projects each one to embed_dim.
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.tuple_patch_size,
            stride=self.tuple_patch_size,
        )
        self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of clips: (B, C, T, H, W) -> (B, T', H', W', embed_dim)."""
        _, _, depth, height, width = x.size()
        # Pad so every axis divides evenly into whole patches.
        pad_d, pad_h, pad_w = _compute_pad_size_3d((depth, height, width), self.tuple_patch_size)
        x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))
        x = self.proj(x)  # B C T Wh Ww
        x = x.permute(0, 2, 3, 4, 1)  # B T Wh Ww C
        if self.norm is not None:
            x = self.norm(x)
        return x
class SwinTransformer3d(nn.Module):
    """
    Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 400.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
        patch_embed (nn.Module, optional): Patch Embedding layer. Default: None.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 400,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        # Default building blocks: the 2d SwinTransformerBlock re-used with the
        # 3d shifted-window attention layer defined in this file.
        if block is None:
            block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d)

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        if patch_embed is None:
            patch_embed = PatchEmbed3d

        # split image into non-overlapping patches
        self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=dropout)

        layers: List[nn.Module] = []
        total_stage_blocks = sum(depths)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            # Channel width doubles at every stage.
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                # NOTE(review): divides by (total_stage_blocks - 1); a configuration with a
                # single block overall would divide by zero -- confirm depths always sum > 1.
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        # Alternate between non-shifted and half-window-shifted blocks.
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                        attn_layer=ShiftedWindowAttention3d,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer (between stages only, not after the last one)
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        self.num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.head = nn.Linear(self.num_features, num_classes)

        # Initialize every linear layer with truncated-normal weights and zero bias.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        # x: B C T H W
        x = self.patch_embed(x)  # B _T _H _W C
        x = self.pos_drop(x)
        x = self.features(x)  # B _T _H _W C
        x = self.norm(x)
        x = x.permute(0, 4, 1, 2, 3)  # B, C, _T, _H, _W
        # Global average pool over all remaining spatio-temporal positions,
        # then classify.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
def _swin_transformer3d(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer3d:
    """Instantiate a SwinTransformer3d and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` is forced to match the weight
    metadata and the (hash-checked) state dict is loaded into the model.
    """
    pretrained = weights is not None
    if pretrained:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SwinTransformer3d(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )
    if pretrained:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
    return model
# Metadata fields shared by every Swin3D weight entry below.
_COMMON_META = {
    "categories": _KINETICS400_CATEGORIES,
    "min_size": (1, 1),  # minimum spatial (H, W) input size declared for these weights
    "min_temporal_size": 1,  # minimum number of frames declared for these weights
}
class Swin3D_T_Weights(WeightsEnum):
    """Pretrained weights for the tiny 3d Swin Transformer (see :func:`swin3d_t`)."""

    # Ported from the reference "Video Swin Transformer" release; see the recipe link in `meta`.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth",
        # Evaluation preprocessing: resize to 256, center-crop to 224x224,
        # normalize with the listed mean/std.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 28158070,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 77.715,
                    "acc@5": 93.519,
                }
            },
            "_ops": 43.882,
            "_file_size": 121.543,
        },
    )
    DEFAULT = KINETICS400_V1
class Swin3D_S_Weights(WeightsEnum):
    """Pretrained weights for the small 3d Swin Transformer (see the small builder below)."""

    # Ported from the reference "Video Swin Transformer" release; see the recipe link in `meta`.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_s-da41c237.pth",
        # Evaluation preprocessing: resize to 256, center-crop to 224x224,
        # normalize with the listed mean/std.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 49816678,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.521,
                    "acc@5": 94.158,
                }
            },
            "_ops": 82.841,
            "_file_size": 218.288,
        },
    )
    DEFAULT = KINETICS400_V1
class Swin3D_B_Weights(WeightsEnum):
    """Pretrained weights for the base 3d Swin Transformer."""

    # Base model; the "_1k" suffix in the URL distinguishes it from the 22k variant below.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth",
        # Evaluation preprocessing: resize to 256, center-crop to 224x224,
        # normalize with the listed mean/std.
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.427,
                    "acc@5": 94.386,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    # Same architecture; per the member name and "_22k" URL this variant was
    # pretrained on ImageNet-22K, yielding higher Kinetics-400 accuracy.
    KINETICS400_IMAGENET22K_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 81.643,
                    "acc@5": 95.574,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    DEFAULT = KINETICS400_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1))
def swin3d_t(
    *, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any
) -> SwinTransformer3d:
    """
    Constructs a swin_tiny architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_T_Weights
        :members:
    """
    # Validate the requested checkpoint (None means "no pretrained weights").
    weights = Swin3D_T_Weights.verify(weights)

    # Hyper-parameters defining the "tiny" Video Swin variant.
    arch_config = dict(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
    )
    # A user-supplied kwarg duplicating an architecture key raises TypeError,
    # exactly as the original explicit-keyword call would.
    return _swin_transformer3d(weights=weights, progress=progress, **arch_config, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1))
def swin3d_s(
    *, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any
) -> SwinTransformer3d:
    """
    Constructs a swin_small architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_S_Weights
        :members:
    """
    # Validate the requested checkpoint (None means "no pretrained weights").
    weights = Swin3D_S_Weights.verify(weights)

    # "Small" differs from "tiny" only in a deeper third stage (18 vs 6 blocks).
    arch_config = dict(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
    )
    # Duplicate keys between arch_config and kwargs raise TypeError, matching
    # the original explicit-keyword call.
    return _swin_transformer3d(weights=weights, progress=progress, **arch_config, **kwargs)
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1))
def swin3d_b(
    *, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any
) -> SwinTransformer3d:
    """
    Constructs a swin_base architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_B_Weights
        :members:
    """
    # Validate the requested checkpoint (None means "no pretrained weights").
    weights = Swin3D_B_Weights.verify(weights)

    # "Base" widens the embedding (128) and attention heads over "small".
    arch_config = dict(
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
    )
    # Duplicate keys between arch_config and kwargs raise TypeError, matching
    # the original explicit-keyword call.
    return _swin_transformer3d(weights=weights, progress=progress, **arch_config, **kwargs)
torchvision/models/vision_transformer.py
View file @
cc26cd81
...
@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
...
@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
def
forward
(
self
,
input
:
torch
.
Tensor
):
def
forward
(
self
,
input
:
torch
.
Tensor
):
torch
.
_assert
(
input
.
dim
()
==
3
,
f
"Expected (batch_size, seq_length, hidden_dim) got
{
input
.
shape
}
"
)
torch
.
_assert
(
input
.
dim
()
==
3
,
f
"Expected (batch_size, seq_length, hidden_dim) got
{
input
.
shape
}
"
)
x
=
self
.
ln_1
(
input
)
x
=
self
.
ln_1
(
input
)
x
,
_
=
self
.
self_attention
(
query
=
x
,
key
=
x
,
value
=
x
,
need_weights
=
False
)
x
,
_
=
self
.
self_attention
(
x
,
x
,
x
,
need_weights
=
False
)
x
=
self
.
dropout
(
x
)
x
=
self
.
dropout
(
x
)
x
=
x
+
input
x
=
x
+
input
...
@@ -332,7 +332,7 @@ def _vision_transformer(
...
@@ -332,7 +332,7 @@ def _vision_transformer(
)
)
if
weights
:
if
weights
:
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
))
model
.
load_state_dict
(
weights
.
get_state_dict
(
progress
=
progress
,
check_hash
=
True
))
return
model
return
model
...
@@ -363,6 +363,8 @@ class ViT_B_16_Weights(WeightsEnum):
...
@@ -363,6 +363,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5"
:
95.318
,
"acc@5"
:
95.318
,
}
}
},
},
"_ops"
:
17.564
,
"_file_size"
:
330.285
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using a modified version of `DeIT
These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
...
@@ -387,6 +389,8 @@ class ViT_B_16_Weights(WeightsEnum):
...
@@ -387,6 +389,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5"
:
97.650
,
"acc@5"
:
97.650
,
}
}
},
},
"_ops"
:
55.484
,
"_file_size"
:
331.398
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -412,6 +416,8 @@ class ViT_B_16_Weights(WeightsEnum):
...
@@ -412,6 +416,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5"
:
96.180
,
"acc@5"
:
96.180
,
}
}
},
},
"_ops"
:
17.564
,
"_file_size"
:
330.285
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -436,6 +442,8 @@ class ViT_B_32_Weights(WeightsEnum):
...
@@ -436,6 +442,8 @@ class ViT_B_32_Weights(WeightsEnum):
"acc@5"
:
92.466
,
"acc@5"
:
92.466
,
}
}
},
},
"_ops"
:
4.409
,
"_file_size"
:
336.604
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using a modified version of `DeIT
These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
...
@@ -460,6 +468,8 @@ class ViT_L_16_Weights(WeightsEnum):
...
@@ -460,6 +468,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5"
:
94.638
,
"acc@5"
:
94.638
,
}
}
},
},
"_ops"
:
61.555
,
"_file_size"
:
1161.023
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using a modified version of TorchVision's
These weights were trained from scratch by using a modified version of TorchVision's
`new training recipe
`new training recipe
...
@@ -485,6 +495,8 @@ class ViT_L_16_Weights(WeightsEnum):
...
@@ -485,6 +495,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5"
:
98.512
,
"acc@5"
:
98.512
,
}
}
},
},
"_ops"
:
361.986
,
"_file_size"
:
1164.258
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -510,6 +522,8 @@ class ViT_L_16_Weights(WeightsEnum):
...
@@ -510,6 +522,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5"
:
97.422
,
"acc@5"
:
97.422
,
}
}
},
},
"_ops"
:
61.555
,
"_file_size"
:
1161.023
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -534,6 +548,8 @@ class ViT_L_32_Weights(WeightsEnum):
...
@@ -534,6 +548,8 @@ class ViT_L_32_Weights(WeightsEnum):
"acc@5"
:
93.07
,
"acc@5"
:
93.07
,
}
}
},
},
"_ops"
:
15.378
,
"_file_size"
:
1169.449
,
"_docs"
:
"""
"_docs"
:
"""
These weights were trained from scratch by using a modified version of `DeIT
These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
<https://arxiv.org/abs/2012.12877>`_'s training recipe.
...
@@ -562,6 +578,8 @@ class ViT_H_14_Weights(WeightsEnum):
...
@@ -562,6 +578,8 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5"
:
98.694
,
"acc@5"
:
98.694
,
}
}
},
},
"_ops"
:
1016.717
,
"_file_size"
:
2416.643
,
"_docs"
:
"""
"_docs"
:
"""
These weights are learnt via transfer learning by end-to-end fine-tuning the original
These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...
@@ -587,6 +605,8 @@ class ViT_H_14_Weights(WeightsEnum):
...
@@ -587,6 +605,8 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5"
:
97.730
,
"acc@5"
:
97.730
,
}
}
},
},
"_ops"
:
167.295
,
"_file_size"
:
2411.209
,
"_docs"
:
"""
"_docs"
:
"""
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...
@@ -773,7 +793,7 @@ def interpolate_embeddings(
...
@@ -773,7 +793,7 @@ def interpolate_embeddings(
interpolation_mode
:
str
=
"bicubic"
,
interpolation_mode
:
str
=
"bicubic"
,
reset_heads
:
bool
=
False
,
reset_heads
:
bool
=
False
,
)
->
"OrderedDict[str, torch.Tensor]"
:
)
->
"OrderedDict[str, torch.Tensor]"
:
"""This function helps interpolat
ing
positional embeddings during checkpoint loading,
"""This function helps interpolat
e
positional embeddings during checkpoint loading,
especially when you want to apply a pre-trained model on images with different resolution.
especially when you want to apply a pre-trained model on images with different resolution.
Args:
Args:
...
@@ -798,7 +818,7 @@ def interpolate_embeddings(
...
@@ -798,7 +818,7 @@ def interpolate_embeddings(
# We do this by reshaping the positions embeddings to a 2d grid, performing
# We do this by reshaping the positions embeddings to a 2d grid, performing
# an interpolation in the (h, w) space and then reshaping back to a 1d grid.
# an interpolation in the (h, w) space and then reshaping back to a 1d grid.
if
new_seq_length
!=
seq_length
:
if
new_seq_length
!=
seq_length
:
# The class token embedding shouldn't be interpolated so we split it up.
# The class token embedding shouldn't be interpolated
,
so we split it up.
seq_length
-=
1
seq_length
-=
1
new_seq_length
-=
1
new_seq_length
-=
1
pos_embedding_token
=
pos_embedding
[:,
:
1
,
:]
pos_embedding_token
=
pos_embedding
[:,
:
1
,
:]
...
@@ -842,17 +862,3 @@ def interpolate_embeddings(
...
@@ -842,17 +862,3 @@ def interpolate_embeddings(
model_state
=
model_state_copy
model_state
=
model_state_copy
return
model_state
return
model_state
# The dictionary below is internal implementation detail and will be removed in v0.15
from
._utils
import
_ModelURLs
model_urls
=
_ModelURLs
(
{
"vit_b_16"
:
ViT_B_16_Weights
.
IMAGENET1K_V1
.
url
,
"vit_b_32"
:
ViT_B_32_Weights
.
IMAGENET1K_V1
.
url
,
"vit_l_16"
:
ViT_L_16_Weights
.
IMAGENET1K_V1
.
url
,
"vit_l_32"
:
ViT_L_32_Weights
.
IMAGENET1K_V1
.
url
,
}
)
torchvision/ops/_box_convert.py
View file @
cc26cd81
...
@@ -50,7 +50,7 @@ def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
...
@@ -50,7 +50,7 @@ def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
def
_box_xywh_to_xyxy
(
boxes
:
Tensor
)
->
Tensor
:
def
_box_xywh_to_xyxy
(
boxes
:
Tensor
)
->
Tensor
:
"""
"""
Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
(x, y) refers to top left of bouding box.
(x, y) refers to top left of bou
n
ding box.
(w, h) refers to width and height of box.
(w, h) refers to width and height of box.
Args:
Args:
boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
...
...
torchvision/ops/_register_onnx_ops.py
View file @
cc26cd81
...
@@ -2,65 +2,106 @@ import sys
...
@@ -2,65 +2,106 @@ import sys
import
warnings
import
warnings
import
torch
import
torch
from
torch.onnx
import
symbolic_opset11
as
opset11
from
torch.onnx.symbolic_helper
import
parse_args
_onnx_opset_version
=
11
_ONNX_OPSET_VERSION_11
=
11
_ONNX_OPSET_VERSION_16
=
16
BASE_ONNX_OPSET_VERSION
=
_ONNX_OPSET_VERSION_11
def
_register_custom_op
():
@
parse_args
(
"v"
,
"v"
,
"f"
)
from
torch.onnx.symbolic_helper
import
parse_args
def
symbolic_multi_label_nms
(
g
,
boxes
,
scores
,
iou_threshold
):
from
torch.onnx.symbolic_opset11
import
select
,
squeeze
,
unsqueeze
boxes
=
opset11
.
unsqueeze
(
g
,
boxes
,
0
)
from
torch.onnx.symbolic_opset9
import
_cast_Long
scores
=
opset11
.
unsqueeze
(
g
,
opset11
.
unsqueeze
(
g
,
scores
,
0
),
0
)
max_output_per_class
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
sys
.
maxsize
],
dtype
=
torch
.
long
))
@
parse_args
(
"v"
,
"v"
,
"f"
)
iou_threshold
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
iou_threshold
],
dtype
=
torch
.
float
))
def
symbolic_multi_label_nms
(
g
,
boxes
,
scores
,
iou_threshold
):
boxes
=
unsqueeze
(
g
,
boxes
,
0
)
# Cast boxes and scores to float32 in case they are float64 inputs
scores
=
unsqueeze
(
g
,
unsqueeze
(
g
,
scores
,
0
),
0
)
nms_out
=
g
.
op
(
max_output_per_class
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
sys
.
maxsize
],
dtype
=
torch
.
long
))
"NonMaxSuppression"
,
iou_threshold
=
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
iou_threshold
],
dtype
=
torch
.
float
))
g
.
op
(
"Cast"
,
boxes
,
to_i
=
torch
.
onnx
.
TensorProtoDataType
.
FLOAT
),
nms_out
=
g
.
op
(
"NonMaxSuppression"
,
boxes
,
scores
,
max_output_per_class
,
iou_threshold
)
g
.
op
(
"Cast"
,
scores
,
to_i
=
torch
.
onnx
.
TensorProtoDataType
.
FLOAT
),
return
squeeze
(
g
,
select
(
g
,
nms_out
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
))),
1
)
max_output_per_class
,
iou_threshold
,
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
)
def
roi_align
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
return
opset11
.
squeeze
(
batch_indices
=
_cast_Long
(
g
,
opset11
.
select
(
g
,
nms_out
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
2
],
dtype
=
torch
.
long
))),
1
g
,
squeeze
(
g
,
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
),
False
)
)
rois
=
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
# TODO: Remove this warning after ONNX opset 16 is supported.
def
_process_batch_indices_for_roi_align
(
g
,
rois
):
if
aligned
:
indices
=
opset11
.
squeeze
(
warnings
.
warn
(
g
,
opset11
.
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
0
],
dtype
=
torch
.
long
))),
1
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
)
"The workaround is that the user need apply the patch "
return
g
.
op
(
"Cast"
,
indices
,
to_i
=
torch
.
onnx
.
TensorProtoDataType
.
INT64
)
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
)
def
_process_rois_for_roi_align
(
g
,
rois
):
return
opset11
.
select
(
g
,
rois
,
1
,
g
.
op
(
"Constant"
,
value_t
=
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
long
)))
# ONNX doesn't support negative sampling_ratio
if
sampling_ratio
<
0
:
warnings
.
warn
(
def
_process_sampling_ratio_for_roi_align
(
g
,
sampling_ratio
:
int
):
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
if
sampling_ratio
<
0
:
)
warnings
.
warn
(
sampling_ratio
=
0
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
return
g
.
op
(
"The model will be exported with a sampling_ratio of 0."
"RoiAlign"
,
input
,
rois
,
batch_indices
,
spatial_scale_f
=
spatial_scale
,
output_height_i
=
pooled_height
,
output_width_i
=
pooled_width
,
sampling_ratio_i
=
sampling_ratio
,
)
)
sampling_ratio
=
0
return
sampling_ratio
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
)
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_pool
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
):
def
roi_align_opset11
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
roi_pool
=
g
.
op
(
batch_indices
=
_process_batch_indices_for_roi_align
(
g
,
rois
)
"MaxRoiPool"
,
input
,
rois
,
pooled_shape_i
=
(
pooled_height
,
pooled_width
),
spatial_scale_f
=
spatial_scale
rois
=
_process_rois_for_roi_align
(
g
,
rois
)
if
aligned
:
warnings
.
warn
(
"ROIAlign with aligned=True is only supported in opset >= 16. "
"Please export with opset 16 or higher, or use aligned=False."
)
)
return
roi_pool
,
None
sampling_ratio
=
_process_sampling_ratio_for_roi_align
(
g
,
sampling_ratio
)
return
g
.
op
(
"RoiAlign"
,
input
,
rois
,
batch_indices
,
spatial_scale_f
=
spatial_scale
,
output_height_i
=
pooled_height
,
output_width_i
=
pooled_width
,
sampling_ratio_i
=
sampling_ratio
,
)
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
,
"i"
,
"i"
)
def
roi_align_opset16
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
,
sampling_ratio
,
aligned
):
batch_indices
=
_process_batch_indices_for_roi_align
(
g
,
rois
)
rois
=
_process_rois_for_roi_align
(
g
,
rois
)
coordinate_transformation_mode
=
"half_pixel"
if
aligned
else
"output_half_pixel"
sampling_ratio
=
_process_sampling_ratio_for_roi_align
(
g
,
sampling_ratio
)
return
g
.
op
(
"RoiAlign"
,
input
,
rois
,
batch_indices
,
coordinate_transformation_mode_s
=
coordinate_transformation_mode
,
spatial_scale_f
=
spatial_scale
,
output_height_i
=
pooled_height
,
output_width_i
=
pooled_width
,
sampling_ratio_i
=
sampling_ratio
,
)
from
torch.onnx
import
register_custom_op_symbolic
register_custom_op_symbolic
(
"torchvision::nms"
,
symbolic_multi_label_nms
,
_onnx_opset_version
)
@
parse_args
(
"v"
,
"v"
,
"f"
,
"i"
,
"i"
)
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align
,
_onnx_opset_version
)
def
roi_pool
(
g
,
input
,
rois
,
spatial_scale
,
pooled_height
,
pooled_width
):
register_custom_op_symbolic
(
"torchvision::roi_pool"
,
roi_pool
,
_onnx_opset_version
)
roi_pool
=
g
.
op
(
"MaxRoiPool"
,
input
,
rois
,
pooled_shape_i
=
(
pooled_height
,
pooled_width
),
spatial_scale_f
=
spatial_scale
)
return
roi_pool
,
None
def
_register_custom_op
():
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::nms"
,
symbolic_multi_label_nms
,
_ONNX_OPSET_VERSION_11
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset11
,
_ONNX_OPSET_VERSION_11
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_align"
,
roi_align_opset16
,
_ONNX_OPSET_VERSION_16
)
torch
.
onnx
.
register_custom_op_symbolic
(
"torchvision::roi_pool"
,
roi_pool
,
_ONNX_OPSET_VERSION_11
)
torchvision/ops/ciou_loss.py
View file @
cc26cd81
...
@@ -63,9 +63,16 @@ def complete_box_iou_loss(
...
@@ -63,9 +63,16 @@ def complete_box_iou_loss(
alpha
=
v
/
(
1
-
iou
+
v
+
eps
)
alpha
=
v
/
(
1
-
iou
+
v
+
eps
)
loss
=
diou_loss
+
alpha
*
v
loss
=
diou_loss
+
alpha
*
v
if
reduction
==
"mean"
:
# Check reduction option and return loss accordingly
if
reduction
==
"none"
:
pass
elif
reduction
==
"mean"
:
loss
=
loss
.
mean
()
if
loss
.
numel
()
>
0
else
0.0
*
loss
.
sum
()
loss
=
loss
.
mean
()
if
loss
.
numel
()
>
0
else
0.0
*
loss
.
sum
()
elif
reduction
==
"sum"
:
elif
reduction
==
"sum"
:
loss
=
loss
.
sum
()
loss
=
loss
.
sum
()
else
:
raise
ValueError
(
f
"Invalid Value for arg 'reduction': '
{
reduction
}
\n
Supported reduction modes: 'none', 'mean', 'sum'"
)
return
loss
return
loss
torchvision/ops/deform_conv.py
View file @
cc26cd81
...
@@ -68,7 +68,7 @@ def deform_conv2d(
...
@@ -68,7 +68,7 @@ def deform_conv2d(
use_mask
=
mask
is
not
None
use_mask
=
mask
is
not
None
if
mask
is
None
:
if
mask
is
None
:
mask
=
torch
.
zeros
((
input
.
shape
[
0
],
0
),
device
=
input
.
device
,
dtype
=
input
.
dtype
)
mask
=
torch
.
zeros
((
input
.
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
input
.
dtype
)
if
bias
is
None
:
if
bias
is
None
:
bias
=
torch
.
zeros
(
out_channels
,
device
=
input
.
device
,
dtype
=
input
.
dtype
)
bias
=
torch
.
zeros
(
out_channels
,
device
=
input
.
device
,
dtype
=
input
.
dtype
)
...
...
Prev
1
…
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment