OpenDAS / vision / Commits

Commit cc26cd81 ("merge v0.16.0"), authored Nov 27, 2023 by panning
Parents: f78f29f5, fbb4cc54

Changes: 371 files in total; this view shows 20 changed files with 1098 additions and 253 deletions (+1098 / -253).
Changed files shown in this view:

  torchvision/models/quantization/shufflenetv2.py   +10   -15
  torchvision/models/regnet.py                      +70   -26
  torchvision/models/resnet.py                      +41   -26
  torchvision/models/segmentation/deeplabv3.py      +9    -16
  torchvision/models/segmentation/fcn.py            +6    -14
  torchvision/models/segmentation/lraspp.py         +3    -12
  torchvision/models/shufflenetv2.py                +10   -16
  torchvision/models/squeezenet.py                  +5    -13
  torchvision/models/swin_transformer.py            +31   -13
  torchvision/models/vgg.py                         +19   -19
  torchvision/models/video/__init__.py              +1    -0
  torchvision/models/video/mvit.py                  +10   -5
  torchvision/models/video/resnet.py                +7    -1
  torchvision/models/video/s3d.py                   +3    -1
  torchvision/models/video/swin_transformer.py      +743  -0
  torchvision/models/vision_transformer.py          +24   -18
  torchvision/ops/_box_convert.py                   +1    -1
  torchvision/ops/_register_onnx_ops.py             +95   -54
  torchvision/ops/ciou_loss.py                      +9    -2
  torchvision/ops/deform_conv.py                    +1    -1

Too many changes to show: to preserve performance only 371 of 371+ files are displayed.
torchvision/models/quantization/shufflenetv2.py  (view file @ cc26cd81)

@@ -108,7 +108,7 @@ def _shufflenetv2(
     quantize_model(model, backend)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -139,6 +139,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
                 "acc@5": 79.780,
+        "_ops": 0.04,
+        "_file_size": 1.501,

@@ -146,7 +148,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
 class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
-        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
+        url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-1e62bb32.pth",
         transforms=partial(ImageClassification, crop_size=224),
         meta={
             **_COMMON_META,

@@ -158,6 +160,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
                 "acc@5": 87.582,
+        "_ops": 0.145,
+        "_file_size": 2.334,
@@ -178,6 +182,8 @@ class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
                 "acc@5": 90.700,
+        "_ops": 0.296,
+        "_file_size": 3.672,
@@ -198,6 +204,8 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
                 "acc@5": 92.488,
+        "_ops": 0.583,
+        "_file_size": 7.467,

@@ -417,16 +425,3 @@ def shufflenet_v2_x2_0(
     return _shufflenetv2(
         [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
     )
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-from ..shufflenetv2 import model_urls  # noqa: F401
-
-
-quant_model_urls = _ModelURLs(
-    {
-        "shufflenetv2_x0.5_fbgemm": ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
-        "shufflenetv2_x1.0_fbgemm": ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
-    }
-)
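The recurring change in this commit is the extra check_hash=True argument passed when loading pretrained weights. A minimal sketch of what that pattern does, assuming the torchvision multi-weight API and torch.hub's download cache (not code from this diff):

    from torchvision.models.quantization import (
        shufflenet_v2_x1_0,
        ShuffleNet_V2_X1_0_QuantizedWeights,
    )

    # get_state_dict() forwards its keyword arguments to torch.hub.load_state_dict_from_url;
    # with check_hash=True the downloaded file must match the hash suffix embedded in the
    # weight URL (e.g. "...fbgemm-1e62bb32.pth"), so a corrupted or stale cached checkpoint
    # fails loudly instead of loading silently.
    weights = ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    state_dict = weights.get_state_dict(progress=True, check_hash=True)

    # The model builders in this file do the same internally when `weights` is not None.
    model = shufflenet_v2_x1_0(weights=weights, quantize=True)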
torchvision/models/regnet.py  (view file @ cc26cd81)

@@ -212,7 +212,7 @@ class BlockParams:
         **kwargs: Any,
     ) -> "BlockParams":
         """
-        Programatically compute all the per-block settings,
+        Programmatically compute all the per-block settings,
         given the RegNet parameters.

         The first step is to compute the quantized linear block parameters,

@@ -397,7 +397,7 @@ def _regnet(
     model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -428,6 +428,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
                 "acc@5": 91.716,
+        "_ops": 0.402,
+        "_file_size": 16.806,
@@ -444,6 +446,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
                 "acc@5": 92.742,
+        "_ops": 0.402,
+        "_file_size": 16.806,
@@ -468,6 +472,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
                 "acc@5": 93.136,
+        "_ops": 0.834,
+        "_file_size": 24.774,
@@ -484,6 +490,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
                 "acc@5": 94.502,
+        "_ops": 0.834,
+        "_file_size": 24.774,
@@ -508,6 +516,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
                 "acc@5": 93.966,
+        "_ops": 1.612,
+        "_file_size": 43.152,
@@ -524,6 +534,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
                 "acc@5": 95.444,
+        "_ops": 1.612,
+        "_file_size": 43.152,
@@ -548,6 +560,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
                 "acc@5": 94.576,
+        "_ops": 3.176,
+        "_file_size": 74.567,
@@ -564,6 +578,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
                 "acc@5": 95.972,
+        "_ops": 3.176,
+        "_file_size": 74.567,
@@ -588,6 +604,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
                 "acc@5": 95.048,
+        "_ops": 8.473,
+        "_file_size": 150.701,
@@ -604,6 +622,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
                 "acc@5": 96.330,
+        "_ops": 8.473,
+        "_file_size": 150.701,
@@ -628,6 +648,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
                 "acc@5": 95.240,
+        "_ops": 15.912,
+        "_file_size": 319.49,
@@ -644,6 +666,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
                 "acc@5": 96.328,
+        "_ops": 15.912,
+        "_file_size": 319.49,
@@ -665,6 +689,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
                 "acc@5": 98.054,
+        "_ops": 46.735,
+        "_file_size": 319.49,
@@ -686,6 +712,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
                 "acc@5": 97.244,
+        "_ops": 15.912,
+        "_file_size": 319.49,
@@ -709,6 +737,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
                 "acc@5": 95.340,
+        "_ops": 32.28,
+        "_file_size": 554.076,
@@ -725,6 +755,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
                 "acc@5": 96.498,
+        "_ops": 32.28,
+        "_file_size": 554.076,
@@ -746,6 +778,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
                 "acc@5": 98.362,
+        "_ops": 94.826,
+        "_file_size": 554.076,
@@ -767,6 +801,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
                 "acc@5": 97.480,
+        "_ops": 32.28,
+        "_file_size": 554.076,
@@ -791,6 +827,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
                 "acc@5": 98.682,
+        "_ops": 374.57,
+        "_file_size": 2461.564,
@@ -812,6 +850,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
                 "acc@5": 97.844,
+        "_ops": 127.518,
+        "_file_size": 2461.564,
@@ -835,6 +875,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
                 "acc@5": 90.950,
+        "_ops": 0.414,
+        "_file_size": 21.258,
@@ -851,6 +893,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
                 "acc@5": 92.322,
+        "_ops": 0.414,
+        "_file_size": 21.257,
@@ -875,6 +919,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
                 "acc@5": 92.348,
+        "_ops": 0.8,
+        "_file_size": 27.945,
@@ -891,6 +937,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
                 "acc@5": 93.826,
+        "_ops": 0.8,
+        "_file_size": 27.945,
@@ -915,6 +963,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
                 "acc@5": 93.440,
+        "_ops": 1.603,
+        "_file_size": 35.339,
@@ -931,6 +981,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
                 "acc@5": 94.922,
+        "_ops": 1.603,
+        "_file_size": 35.339,
@@ -955,6 +1007,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
                 "acc@5": 93.992,
+        "_ops": 3.177,
+        "_file_size": 58.756,
@@ -971,6 +1025,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
                 "acc@5": 95.430,
+        "_ops": 3.177,
+        "_file_size": 58.756,
@@ -995,6 +1051,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
                 "acc@5": 94.686,
+        "_ops": 7.995,
+        "_file_size": 151.456,
@@ -1011,6 +1069,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
                 "acc@5": 95.678,
+        "_ops": 7.995,
+        "_file_size": 151.456,
@@ -1035,6 +1095,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
                 "acc@5": 94.944,
+        "_ops": 15.941,
+        "_file_size": 207.627,
@@ -1051,6 +1113,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
                 "acc@5": 96.196,
+        "_ops": 15.941,
+        "_file_size": 207.627,
@@ -1075,6 +1139,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
                 "acc@5": 95.248,
+        "_ops": 31.736,
+        "_file_size": 412.039,
@@ -1091,6 +1157,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
                 "acc@5": 96.288,
+        "_ops": 31.736,
+        "_file_size": 412.039,

@@ -1501,27 +1569,3 @@ def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress:
     params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
     return _regnet(params, weights, progress, **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "regnet_y_400mf": RegNet_Y_400MF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_800mf": RegNet_Y_800MF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_1_6gf": RegNet_Y_1_6GF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_3_2gf": RegNet_Y_3_2GF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_8gf": RegNet_Y_8GF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_16gf": RegNet_Y_16GF_Weights.IMAGENET1K_V1.url,
-        "regnet_y_32gf": RegNet_Y_32GF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_400mf": RegNet_X_400MF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_800mf": RegNet_X_800MF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_1_6gf": RegNet_X_1_6GF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_3_2gf": RegNet_X_3_2GF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_8gf": RegNet_X_8GF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_16gf": RegNet_X_16GF_Weights.IMAGENET1K_V1.url,
-        "regnet_x_32gf": RegNet_X_32GF_Weights.IMAGENET1K_V1.url,
-    }
-)
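The "_ops" and "_file_size" fields added throughout this commit are exposed via each weight enum's meta dictionary. A short sketch of how they can be read; the field names come from the diff above, while the units (GFLOPs and megabytes) are assumed from torchvision's model documentation rather than stated here:

    from torchvision.models import RegNet_Y_400MF_Weights

    w = RegNet_Y_400MF_Weights.IMAGENET1K_V2
    # "_ops" is the compute cost and "_file_size" the checkpoint size, matching the
    # values added in the hunk above for this entry (0.402 and 16.806).
    print(w.meta["_ops"], w.meta["_file_size"])
    print(w.meta["_metrics"]["ImageNet-1K"]["acc@5"])  # 92.742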
torchvision/models/resnet.py  (view file @ cc26cd81)

@@ -108,7 +108,7 @@ class BasicBlock(nn.Module):
 class Bottleneck(nn.Module):
     # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
     # while original implementation places the stride at the first 1x1 convolution(self.conv1)
-    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
     # This variant is also known as ResNet V1.5 and improves accuracy according to
     # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

@@ -298,7 +298,7 @@ def _resnet(
     model = ResNet(block, layers, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -323,6 +323,8 @@ class ResNet18_Weights(WeightsEnum):
                 "acc@5": 89.078,
+        "_ops": 1.814,
+        "_file_size": 44.661,
@@ -343,6 +345,8 @@ class ResNet34_Weights(WeightsEnum):
                 "acc@5": 91.420,
+        "_ops": 3.664,
+        "_file_size": 83.275,
@@ -363,6 +367,8 @@ class ResNet50_Weights(WeightsEnum):
                 "acc@5": 92.862,
+        "_ops": 4.089,
+        "_file_size": 97.781,
@@ -379,6 +385,8 @@ class ResNet50_Weights(WeightsEnum):
                 "acc@5": 95.434,
+        "_ops": 4.089,
+        "_file_size": 97.79,
@@ -402,6 +410,8 @@ class ResNet101_Weights(WeightsEnum):
                 "acc@5": 93.546,
+        "_ops": 7.801,
+        "_file_size": 170.511,
@@ -418,6 +428,8 @@ class ResNet101_Weights(WeightsEnum):
                 "acc@5": 95.780,
+        "_ops": 7.801,
+        "_file_size": 170.53,
@@ -441,6 +453,8 @@ class ResNet152_Weights(WeightsEnum):
                 "acc@5": 94.046,
+        "_ops": 11.514,
+        "_file_size": 230.434,
@@ -457,6 +471,8 @@ class ResNet152_Weights(WeightsEnum):
                 "acc@5": 96.002,
+        "_ops": 11.514,
+        "_file_size": 230.474,
@@ -480,6 +496,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
                 "acc@5": 93.698,
+        "_ops": 4.23,
+        "_file_size": 95.789,
@@ -496,6 +514,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
                 "acc@5": 95.340,
+        "_ops": 4.23,
+        "_file_size": 95.833,
@@ -519,6 +539,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
                 "acc@5": 94.526,
+        "_ops": 16.414,
+        "_file_size": 339.586,
@@ -535,6 +557,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
                 "acc@5": 96.228,
+        "_ops": 16.414,
+        "_file_size": 339.673,
@@ -558,6 +582,8 @@ class ResNeXt101_64X4D_Weights(WeightsEnum):
                 "acc@5": 96.454,
+        "_ops": 15.46,
+        "_file_size": 319.318,
@@ -581,6 +607,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
                 "acc@5": 94.086,
+        "_ops": 11.398,
+        "_file_size": 131.82,
@@ -597,6 +625,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
                 "acc@5": 95.758,
+        "_ops": 11.398,
+        "_file_size": 263.124,
@@ -620,6 +650,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
                 "acc@5": 94.284,
+        "_ops": 22.753,
+        "_file_size": 242.896,
@@ -636,6 +668,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
                 "acc@5": 96.020,
+        "_ops": 22.753,
+        "_file_size": 484.747,

@@ -648,7 +682,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
 def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
-    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+    """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
@@ -674,7 +708,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
 def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
-    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+    """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
@@ -700,7 +734,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
 def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
-    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+    """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
@@ -732,7 +766,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
 def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
-    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+    """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
@@ -764,7 +798,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
 def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
-    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__.
+    """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.

@@ -949,22 +983,3 @@ def wide_resnet101_2(
     _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
     return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "resnet18": ResNet18_Weights.IMAGENET1K_V1.url,
-        "resnet34": ResNet34_Weights.IMAGENET1K_V1.url,
-        "resnet50": ResNet50_Weights.IMAGENET1K_V1.url,
-        "resnet101": ResNet101_Weights.IMAGENET1K_V1.url,
-        "resnet152": ResNet152_Weights.IMAGENET1K_V1.url,
-        "resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url,
-        "resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url,
-        "wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url,
-        "wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url,
-    }
-)
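With the deprecated model_urls dictionaries removed across these files, pretrained checkpoints are reached only through the weight enums. A minimal sketch of the replacement pattern, using weight names that appear in the diff above:

    import torch
    from torchvision.models import resnet50, ResNet50_Weights

    # Replaces the old model_urls lookup and the legacy pretrained=True flag.
    weights = ResNet50_Weights.IMAGENET1K_V2
    model = resnet50(weights=weights)
    model.eval()

    # The enum also carries the matching preprocessing transforms and metadata.
    preprocess = weights.transforms()
    batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)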
torchvision/models/segmentation/deeplabv3.py  (view file @ cc26cd81)

@@ -152,6 +152,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
                 "pixel_acc": 92.4,
+        "_ops": 178.722,
+        "_file_size": 160.515,
@@ -171,6 +173,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
                 "pixel_acc": 92.4,
+        "_ops": 258.743,
+        "_file_size": 233.217,
@@ -190,6 +194,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
                 "pixel_acc": 91.2,
+        "_ops": 10.452,
+        "_file_size": 42.301,

@@ -269,7 +275,7 @@ def deeplabv3_resnet50(
     model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

@@ -325,7 +331,7 @@ def deeplabv3_resnet101(
     model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

@@ -379,19 +385,6 @@ def deeplabv3_mobilenet_v3_large(
     model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "deeplabv3_resnet50_coco": DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
-        "deeplabv3_resnet101_coco": DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
-        "deeplabv3_mobilenet_v3_large_coco": DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
-    }
-)
torchvision/models/segmentation/fcn.py  (view file @ cc26cd81)

@@ -71,6 +71,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
                 "pixel_acc": 91.4,
+        "_ops": 152.717,
+        "_file_size": 135.009,
@@ -90,6 +92,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
                 "pixel_acc": 91.9,
+        "_ops": 232.738,
+        "_file_size": 207.711,

@@ -164,7 +168,7 @@ def fcn_resnet50(
     model = _fcn_resnet(backbone, num_classes, aux_loss)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

@@ -223,18 +227,6 @@ def fcn_resnet101(
     model = _fcn_resnet(backbone, num_classes, aux_loss)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "fcn_resnet50_coco": FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
-        "fcn_resnet101_coco": FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
-    }
-)
torchvision/models/segmentation/lraspp.py  (view file @ cc26cd81)

@@ -108,6 +108,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
                 "pixel_acc": 91.2,
+        "_ops": 2.086,
+        "_file_size": 12.49,
         "_docs": """
             These weights were trained on a subset of COCO, using only the 20 categories that are present in the
             Pascal VOC dataset.

@@ -171,17 +173,6 @@ def lraspp_mobilenet_v3_large(
     model = _lraspp_mobilenetv3(backbone, num_classes)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from .._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "lraspp_mobilenet_v3_large_coco": LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
-    }
-)
torchvision/models/shufflenetv2.py  (view file @ cc26cd81)

@@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
     x = torch.transpose(x, 1, 2).contiguous()

     # flatten
-    x = x.view(batchsize, -1, height, width)
+    x = x.view(batchsize, num_channels, height, width)

     return x

@@ -178,7 +178,7 @@ def _shufflenetv2(
     model = ShuffleNetV2(*args, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -204,6 +204,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
                 "acc@5": 81.746,
+        "_ops": 0.04,
+        "_file_size": 5.282,
@@ -224,6 +226,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
                 "acc@5": 88.316,
+        "_ops": 0.145,
+        "_file_size": 8.791,
@@ -244,6 +248,8 @@ class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
                 "acc@5": 91.086,
+        "_ops": 0.296,
+        "_file_size": 13.557,
@@ -267,6 +273,8 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
                 "acc@5": 93.006,
+        "_ops": 0.583,
+        "_file_size": 28.433,

@@ -398,17 +406,3 @@ def shufflenet_v2_x2_0(
     weights = ShuffleNet_V2_X2_0_Weights.verify(weights)

     return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "shufflenetv2_x0.5": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1.url,
-        "shufflenetv2_x1.0": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1.url,
-        "shufflenetv2_x1.5": None,
-        "shufflenetv2_x2.0": None,
-    }
-)
torchvision/models/squeezenet.py  (view file @ cc26cd81)

@@ -109,7 +109,7 @@ def _squeezenet(
     model = SqueezeNet(version, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -135,6 +135,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
                 "acc@5": 80.420,
+        "_ops": 0.819,
+        "_file_size": 4.778,
@@ -154,6 +156,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
                 "acc@5": 80.624,
+        "_ops": 0.349,
+        "_file_size": 4.729,

@@ -217,15 +221,3 @@ def squeezenet1_1(
     """
     weights = SqueezeNet1_1_Weights.verify(weights)
     return _squeezenet("1_1", weights, progress, **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "squeezenet1_0": SqueezeNet1_0_Weights.IMAGENET1K_V1.url,
-        "squeezenet1_1": SqueezeNet1_1_Weights.IMAGENET1K_V1.url,
-    }
-)
torchvision/models/swin_transformer.py  (view file @ cc26cd81)

@@ -126,7 +126,8 @@ def shifted_window_attention(
     qkv_bias: Optional[Tensor] = None,
     proj_bias: Optional[Tensor] = None,
     logit_scale: Optional[torch.Tensor] = None,
-):
+    training: bool = True,
+) -> Tensor:
     """
     Window based multi-head self attention (W-MSA) module with relative position bias.
     It supports both of shifted and non-shifted window.

@@ -143,6 +144,7 @@ def shifted_window_attention(
         qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
         proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
         logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
+        training (bool, optional): Training flag used by the dropout parameters. Default: True.
     Returns:
         Tensor[N, H, W, C]: The output tensor after shifted window attention.
     """

@@ -207,11 +209,11 @@ def shifted_window_attention(
     attn = attn.view(-1, num_heads, x.size(1), x.size(1))
     attn = F.softmax(attn, dim=-1)
-    attn = F.dropout(attn, p=attention_dropout)
+    attn = F.dropout(attn, p=attention_dropout, training=training)

     x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
     x = F.linear(x, proj_weight, proj_bias)
-    x = F.dropout(x, p=dropout)
+    x = F.dropout(x, p=dropout, training=training)

     # reverse windows
     x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)

@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
             self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
         )

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         """
         Args:
             x (Tensor): Tensor with layout of [B, H, W, C]

@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
             dropout=self.dropout,
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
+            training=self.training,
         )

@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
             qkv_bias=self.qkv.bias,
             proj_bias=self.proj.bias,
             logit_scale=self.logit_scale,
+            training=self.training,
         )

@@ -494,6 +498,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
         )

     def forward(self, x: Tensor):
+        # Here is the difference, we apply norm after the attention in V2.
+        # In V1 we applied norm before the attention.
         x = x + self.stochastic_depth(self.norm1(self.attn(x)))
         x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
         return x

@@ -502,7 +508,7 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
 class SwinTransformer(nn.Module):
     """
     Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
-    Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
+    Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.
     Args:
         patch_size (List[int]): Patch size.
         embed_dim (int): Patch embedding dimension.

@@ -587,7 +593,7 @@ class SwinTransformer(nn.Module):
         num_features = embed_dim * 2 ** (len(depths) - 1)
         self.norm = norm_layer(num_features)
-        self.permute = Permute([0, 3, 1, 2])
+        self.permute = Permute([0, 3, 1, 2])  # B H W C -> B C H W
         self.avgpool = nn.AdaptiveAvgPool2d(1)
         self.flatten = nn.Flatten(1)
         self.head = nn.Linear(num_features, num_classes)

@@ -633,7 +639,7 @@ def _swin_transformer(
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -660,6 +666,8 @@ class Swin_T_Weights(WeightsEnum):
                 "acc@5": 95.776,
+        "_ops": 4.491,
+        "_file_size": 108.19,
@@ -683,6 +691,8 @@ class Swin_S_Weights(WeightsEnum):
                 "acc@5": 96.360,
+        "_ops": 8.741,
+        "_file_size": 189.786,
@@ -706,6 +716,8 @@ class Swin_B_Weights(WeightsEnum):
                 "acc@5": 96.640,
+        "_ops": 15.431,
+        "_file_size": 335.364,
@@ -729,6 +741,8 @@ class Swin_V2_T_Weights(WeightsEnum):
                 "acc@5": 96.132,
+        "_ops": 5.94,
+        "_file_size": 108.626,
@@ -752,6 +766,8 @@ class Swin_V2_S_Weights(WeightsEnum):
                 "acc@5": 96.816,
+        "_ops": 11.546,
+        "_file_size": 190.675,
@@ -775,6 +791,8 @@ class Swin_V2_B_Weights(WeightsEnum):
                 "acc@5": 96.864,
+        "_ops": 20.325,
+        "_file_size": 336.372,

@@ -786,7 +804,7 @@ class Swin_V2_B_Weights(WeightsEnum):
 def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_tiny architecture from
-    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
@@ -824,7 +842,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
 def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_small architecture from
-    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
@@ -862,7 +880,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
 def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_base architecture from
-    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_.
+    `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
@@ -900,7 +918,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
 def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_v2_tiny architecture from
-    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
+    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
@@ -940,7 +958,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T
 def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_v2_small architecture from
-    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
+    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
@@ -980,7 +998,7 @@ def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = T
 def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
     Constructs a swin_v2_base architecture from
-    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_.
+    `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
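The training flag threaded through shifted_window_attention above exists because functional dropout does not know the module's train/eval mode on its own. A small standalone sketch of the behaviour being fixed (illustrative code, not part of the diff):

    import torch
    import torch.nn.functional as F

    x = torch.ones(4, 4)

    # F.dropout defaults to training=True, so it keeps zeroing activations even at eval
    # time unless the caller forwards the module's self.training flag, which is exactly
    # what the added `training=training` arguments do in the attention function.
    eval_out = F.dropout(x, p=0.5, training=False)   # identity: nothing dropped
    train_out = F.dropout(x, p=0.5, training=True)   # roughly half zeroed, rest rescaled

    assert torch.equal(eval_out, x)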
torchvision/models/vgg.py  (view file @ cc26cd81)

@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
         _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
     model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -127,6 +127,8 @@ class VGG11_Weights(WeightsEnum):
                 "acc@5": 88.628,
+        "_ops": 7.609,
+        "_file_size": 506.84,
@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum):
                 "acc@5": 89.810,
+        "_ops": 7.609,
+        "_file_size": 506.881,
@@ -163,6 +167,8 @@ class VGG13_Weights(WeightsEnum):
                 "acc@5": 89.246,
+        "_ops": 11.308,
+        "_file_size": 507.545,
@@ -181,6 +187,8 @@ class VGG13_BN_Weights(WeightsEnum):
                 "acc@5": 90.374,
+        "_ops": 11.308,
+        "_file_size": 507.59,
@@ -199,6 +207,8 @@ class VGG16_Weights(WeightsEnum):
                 "acc@5": 90.382,
+        "_ops": 15.47,
+        "_file_size": 527.796,
     IMAGENET1K_FEATURES = Weights(
@@ -221,6 +231,8 @@ class VGG16_Weights(WeightsEnum):
                 "acc@5": float("nan"),
+        "_ops": 15.47,
+        "_file_size": 527.802,
         "_docs": """
             These weights can't be used for classification because they are missing values in the `classifier`
             module. Only the `features` module has valid values and can be used for feature extraction. The weights
@@ -244,6 +256,8 @@ class VGG16_BN_Weights(WeightsEnum):
                 "acc@5": 91.516,
+        "_ops": 15.47,
+        "_file_size": 527.866,
@@ -262,6 +276,8 @@ class VGG19_Weights(WeightsEnum):
                 "acc@5": 90.876,
+        "_ops": 19.632,
+        "_file_size": 548.051,
@@ -280,6 +296,8 @@ class VGG19_BN_Weights(WeightsEnum):
                 "acc@5": 91.842,
+        "_ops": 19.632,
+        "_file_size": 548.143,

@@ -491,21 +509,3 @@ def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = Tru
     weights = VGG19_BN_Weights.verify(weights)
     return _vgg("E", True, weights, progress, **kwargs)
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "vgg11": VGG11_Weights.IMAGENET1K_V1.url,
-        "vgg13": VGG13_Weights.IMAGENET1K_V1.url,
-        "vgg16": VGG16_Weights.IMAGENET1K_V1.url,
-        "vgg19": VGG19_Weights.IMAGENET1K_V1.url,
-        "vgg11_bn": VGG11_BN_Weights.IMAGENET1K_V1.url,
-        "vgg13_bn": VGG13_BN_Weights.IMAGENET1K_V1.url,
-        "vgg16_bn": VGG16_BN_Weights.IMAGENET1K_V1.url,
-        "vgg19_bn": VGG19_BN_Weights.IMAGENET1K_V1.url,
-    }
-)
torchvision/models/video/__init__.py  (view file @ cc26cd81)

 from .mvit import *
 from .resnet import *
 from .s3d import *
+from .swin_transformer import *
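With this new export, the 3D Swin builders added later in this commit become importable from the video namespace. A hedged usage sketch, with builder and weight names taken from the new file's __all__ and weight classes; the default head size and layout follow the usual torchvision video-model conventions:

    import torch
    from torchvision.models.video import swin3d_t

    # Build the tiny 3D Swin video classifier; weights=None keeps it randomly initialised.
    model = swin3d_t(weights=None)
    model.eval()

    # Video input layout is (batch, channels, time, height, width).
    clip = torch.rand(1, 3, 16, 224, 224)
    with torch.no_grad():
        scores = model(clip)  # Kinetics-400-sized logits by default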
torchvision/models/video/mvit.py  (view file @ cc26cd81)

@@ -593,7 +593,7 @@ def _mvit(
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -624,6 +624,8 @@ class MViT_V1_B_Weights(WeightsEnum):
                 "acc@5": 93.582,
+        "_ops": 70.599,
+        "_file_size": 139.764,
@@ -655,6 +657,8 @@ class MViT_V2_S_Weights(WeightsEnum):
                 "acc@5": 94.665,
+        "_ops": 64.224,
+        "_file_size": 131.884,

@@ -761,9 +765,10 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
 @register_model()
 @handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
 def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
-    """
-    Constructs a small MViTV2 architecture from
-    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__.
+    """Constructs a small MViTV2 architecture from
+    `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
+    `MViTv2: Improved Multiscale Vision Transformers for Classification
+    and Detection <https://arxiv.org/abs/2112.01526>`__.

     .. betastatus:: video module
torchvision/models/video/resnet.py  (view file @ cc26cd81)

@@ -303,7 +303,7 @@ def _video_resnet(
     model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model

@@ -332,6 +332,8 @@ class R3D_18_Weights(WeightsEnum):
                 "acc@5": 83.479,
+        "_ops": 40.697,
+        "_file_size": 127.359,
@@ -350,6 +352,8 @@ class MC3_18_Weights(WeightsEnum):
                 "acc@5": 84.130,
+        "_ops": 43.343,
+        "_file_size": 44.672,
@@ -368,6 +372,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
                 "acc@5": 86.175,
+        "_ops": 40.519,
+        "_file_size": 120.318,
torchvision/models/video/s3d.py  (view file @ cc26cd81)

@@ -175,6 +175,8 @@ class S3D_Weights(WeightsEnum):
                 "acc@5": 88.050,
+        "_ops": 17.979,
+        "_file_size": 31.972,

@@ -212,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
     model = S3D(**kwargs)
     if weights is not None:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
     return model
torchvision/models/video/swin_transformer.py  (new file, 0 → 100644, view file @ cc26cd81)

# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py

from functools import partial
from typing import Any, Callable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..swin_transformer import PatchMerging, SwinTransformerBlock

__all__ = [
    "SwinTransformer3d",
    "Swin3D_T_Weights",
    "Swin3D_S_Weights",
    "Swin3D_B_Weights",
    "swin3d_t",
    "swin3d_s",
    "swin3d_b",
]


def _get_window_and_shift_size(
    shift_size: List[int], size_dhw: List[int], window_size: List[int]
) -> Tuple[List[int], List[int]]:
    for i in range(3):
        if size_dhw[i] <= window_size[i]:
            # In this case, window_size will adapt to the input size, and no need to shift
            window_size[i] = size_dhw[i]
            shift_size[i] = 0

    return window_size, shift_size


torch.fx.wrap("_get_window_and_shift_size")


def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> Tensor:
    window_vol = window_size[0] * window_size[1] * window_size[2]
    # In 3d case we flatten the relative_position_bias
    relative_position_bias = relative_position_bias_table[
        relative_position_index[:window_vol, :window_vol].flatten()  # type: ignore[index]
    ]
    relative_position_bias = relative_position_bias.view(window_vol, window_vol, -1)
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
    return relative_position_bias


torch.fx.wrap("_get_relative_position_bias")


def _compute_pad_size_3d(size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
    pad_size = [(patch_size[i] - size_dhw[i] % patch_size[i]) % patch_size[i] for i in range(3)]
    return pad_size[0], pad_size[1], pad_size[2]


torch.fx.wrap("_compute_pad_size_3d")
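# Worked example (illustrative values, not part of the original file): for an input of
# size_dhw = (15, 110, 110) and patch_size = (2, 4, 4), the formula above yields
# ((2 - 15 % 2) % 2, (4 - 110 % 4) % 4, (4 - 110 % 4) % 4) == (1, 2, 2), i.e. the amount
# of padding needed to make every dimension divisible by the patch/window size; an
# already divisible input such as (16, 112, 112) maps to (0, 0, 0).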
def
_compute_attention_mask_3d
(
x
:
Tensor
,
size_dhw
:
Tuple
[
int
,
int
,
int
],
window_size
:
Tuple
[
int
,
int
,
int
],
shift_size
:
Tuple
[
int
,
int
,
int
],
)
->
Tensor
:
# generate attention mask
attn_mask
=
x
.
new_zeros
(
*
size_dhw
)
num_windows
=
(
size_dhw
[
0
]
//
window_size
[
0
])
*
(
size_dhw
[
1
]
//
window_size
[
1
])
*
(
size_dhw
[
2
]
//
window_size
[
2
])
slices
=
[
(
(
0
,
-
window_size
[
i
]),
(
-
window_size
[
i
],
-
shift_size
[
i
]),
(
-
shift_size
[
i
],
None
),
)
for
i
in
range
(
3
)
]
count
=
0
for
d
in
slices
[
0
]:
for
h
in
slices
[
1
]:
for
w
in
slices
[
2
]:
attn_mask
[
d
[
0
]
:
d
[
1
],
h
[
0
]
:
h
[
1
],
w
[
0
]
:
w
[
1
]]
=
count
count
+=
1
# Partition window on attn_mask
attn_mask
=
attn_mask
.
view
(
size_dhw
[
0
]
//
window_size
[
0
],
window_size
[
0
],
size_dhw
[
1
]
//
window_size
[
1
],
window_size
[
1
],
size_dhw
[
2
]
//
window_size
[
2
],
window_size
[
2
],
)
attn_mask
=
attn_mask
.
permute
(
0
,
2
,
4
,
1
,
3
,
5
).
reshape
(
num_windows
,
window_size
[
0
]
*
window_size
[
1
]
*
window_size
[
2
]
)
attn_mask
=
attn_mask
.
unsqueeze
(
1
)
-
attn_mask
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
torch
.
fx
.
wrap
(
"_compute_attention_mask_3d"
)
def
shifted_window_attention_3d
(
input
:
Tensor
,
qkv_weight
:
Tensor
,
proj_weight
:
Tensor
,
relative_position_bias
:
Tensor
,
window_size
:
List
[
int
],
num_heads
:
int
,
shift_size
:
List
[
int
],
attention_dropout
:
float
=
0.0
,
dropout
:
float
=
0.0
,
qkv_bias
:
Optional
[
Tensor
]
=
None
,
proj_bias
:
Optional
[
Tensor
]
=
None
,
training
:
bool
=
True
,
)
->
Tensor
:
"""
Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions.
qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
relative_position_bias (Tensor): The learned relative position bias added to attention.
window_size (List[int]): 3-dimensions window size, T, H, W .
num_heads (int): Number of attention heads.
shift_size (List[int]): Shift size for shifted window attention (T, H, W).
attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
dropout (float): Dropout ratio of output. Default: 0.0.
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
training (bool, optional): Training flag used by the dropout parameters. Default: True.
Returns:
Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
"""
b
,
t
,
h
,
w
,
c
=
input
.
shape
# pad feature maps to multiples of window size
pad_size
=
_compute_pad_size_3d
((
t
,
h
,
w
),
(
window_size
[
0
],
window_size
[
1
],
window_size
[
2
]))
x
=
F
.
pad
(
input
,
(
0
,
0
,
0
,
pad_size
[
2
],
0
,
pad_size
[
1
],
0
,
pad_size
[
0
]))
_
,
tp
,
hp
,
wp
,
_
=
x
.
shape
padded_size
=
(
tp
,
hp
,
wp
)
# cyclic shift
if
sum
(
shift_size
)
>
0
:
x
=
torch
.
roll
(
x
,
shifts
=
(
-
shift_size
[
0
],
-
shift_size
[
1
],
-
shift_size
[
2
]),
dims
=
(
1
,
2
,
3
))
# partition windows
num_windows
=
(
(
padded_size
[
0
]
//
window_size
[
0
])
*
(
padded_size
[
1
]
//
window_size
[
1
])
*
(
padded_size
[
2
]
//
window_size
[
2
])
)
x
=
x
.
view
(
b
,
padded_size
[
0
]
//
window_size
[
0
],
window_size
[
0
],
padded_size
[
1
]
//
window_size
[
1
],
window_size
[
1
],
padded_size
[
2
]
//
window_size
[
2
],
window_size
[
2
],
c
,
)
x
=
x
.
permute
(
0
,
1
,
3
,
5
,
2
,
4
,
6
,
7
).
reshape
(
b
*
num_windows
,
window_size
[
0
]
*
window_size
[
1
]
*
window_size
[
2
],
c
)
# B*nW, Wd*Wh*Ww, C
# multi-head attention
qkv
=
F
.
linear
(
x
,
qkv_weight
,
qkv_bias
)
qkv
=
qkv
.
reshape
(
x
.
size
(
0
),
x
.
size
(
1
),
3
,
num_heads
,
c
//
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
q
=
q
*
(
c
//
num_heads
)
**
-
0.5
attn
=
q
.
matmul
(
k
.
transpose
(
-
2
,
-
1
))
# add relative position bias
attn
=
attn
+
relative_position_bias
if
sum
(
shift_size
)
>
0
:
# generate attention mask to handle shifted windows with varying size
attn_mask
=
_compute_attention_mask_3d
(
x
,
(
padded_size
[
0
],
padded_size
[
1
],
padded_size
[
2
]),
(
window_size
[
0
],
window_size
[
1
],
window_size
[
2
]),
(
shift_size
[
0
],
shift_size
[
1
],
shift_size
[
2
]),
)
attn
=
attn
.
view
(
x
.
size
(
0
)
//
num_windows
,
num_windows
,
num_heads
,
x
.
size
(
1
),
x
.
size
(
1
))
attn
=
attn
+
attn_mask
.
unsqueeze
(
1
).
unsqueeze
(
0
)
attn
=
attn
.
view
(
-
1
,
num_heads
,
x
.
size
(
1
),
x
.
size
(
1
))
attn
=
F
.
softmax
(
attn
,
dim
=-
1
)
attn
=
F
.
dropout
(
attn
,
p
=
attention_dropout
,
training
=
training
)
x
=
attn
.
matmul
(
v
).
transpose
(
1
,
2
).
reshape
(
x
.
size
(
0
),
x
.
size
(
1
),
c
)
x
=
F
.
linear
(
x
,
proj_weight
,
proj_bias
)
x
=
F
.
dropout
(
x
,
p
=
dropout
,
training
=
training
)
# reverse windows
x
=
x
.
view
(
b
,
padded_size
[
0
]
//
window_size
[
0
],
padded_size
[
1
]
//
window_size
[
1
],
padded_size
[
2
]
//
window_size
[
2
],
window_size
[
0
],
window_size
[
1
],
window_size
[
2
],
c
,
)
x
=
x
.
permute
(
0
,
1
,
4
,
2
,
5
,
3
,
6
,
7
).
reshape
(
b
,
tp
,
hp
,
wp
,
c
)
# reverse cyclic shift
if
sum
(
shift_size
)
>
0
:
x
=
torch
.
roll
(
x
,
shifts
=
(
shift_size
[
0
],
shift_size
[
1
],
shift_size
[
2
]),
dims
=
(
1
,
2
,
3
))
# unpad features
x
=
x
[:,
:
t
,
:
h
,
:
w
,
:].
contiguous
()
return
x
torch
.
fx
.
wrap
(
"shifted_window_attention_3d"
)
class ShiftedWindowAttention3d(nn.Module):
    """
    See :func:`shifted_window_attention_3d`.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if len(window_size) != 3 or len(shift_size) != 3:
            raise ValueError("window_size and shift_size must be of length 3")
        self.window_size = window_size  # Wd, Wh, Ww
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()
    def define_relative_position_bias_table(self) -> None:
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                self.num_heads,
            )
        )  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self) -> None:
        # get pair-wise relative position index for each token inside the window
        coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)]
        coords = torch.stack(
            torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij")
        )  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        # We don't flatten the relative_position_index here in 3d case.
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor:
        return _get_relative_position_bias(
            self.relative_position_bias_table, self.relative_position_index, window_size  # type: ignore
        )

    def forward(self, x: Tensor) -> Tensor:
        _, t, h, w, _ = x.shape
        size_dhw = [t, h, w]
        window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
        # Handle case where window_size is larger than the input tensor
        window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size)

        relative_position_bias = self.get_relative_position_bias(window_size)

        return shifted_window_attention_3d(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            window_size,
            self.num_heads,
            shift_size=shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )
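A usage sketch for the module above, assuming it is importable from torchvision.models.video.swin_transformer; the tiny sizes are illustrative, not a published configuration:

import torch
from torchvision.models.video.swin_transformer import ShiftedWindowAttention3d

attn = ShiftedWindowAttention3d(dim=8, window_size=[2, 2, 2], shift_size=[1, 1, 1], num_heads=2)
x = torch.rand(1, 4, 8, 8, 8)  # B, T, H, W, C with C == dim
out = attn(x)                  # same shape as the input: (1, 4, 8, 8, 8)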
# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
    """Video to Patch Embedding.

    Args:
        patch_size (List[int]): Patch token size.
        in_channels (int): Number of input channels. Default: 3
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size: List[int],
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2])

        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.tuple_patch_size,
            stride=self.tuple_patch_size,
        )
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        # padding
        _, _, t, h, w = x.size()
        pad_size = _compute_pad_size_3d((t, h, w), self.tuple_patch_size)
        x = F.pad(x, (0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))
        x = self.proj(x)  # B C T Wh Ww
        x = x.permute(0, 2, 3, 4, 1)  # B T Wh Ww C
        if self.norm is not None:
            x = self.norm(x)
        return x
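A quick sketch of what the patch embedding does to a clip, assuming the module path above; the shapes are illustrative only:

import torch
from torchvision.models.video.swin_transformer import PatchEmbed3d

embed = PatchEmbed3d(patch_size=[2, 4, 4], embed_dim=96)
clip = torch.rand(1, 3, 4, 32, 32)   # B C T H W
tokens = embed(clip)
print(tokens.shape)                  # torch.Size([1, 2, 8, 8, 96]) -> B T' H' W' C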
class SwinTransformer3d(nn.Module):
    """
    Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper.

    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 400.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
        patch_embed (nn.Module, optional): Patch Embedding layer. Default: None.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 400,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d)

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        if patch_embed is None:
            patch_embed = PatchEmbed3d

        # split image into non-overlapping patches
        self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=dropout)

        layers: List[nn.Module] = []
        total_stage_blocks = sum(depths)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                        attn_layer=ShiftedWindowAttention3d,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        self.num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.head = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        # x: B C T H W
        x = self.patch_embed(x)  # B _T _H _W C
        x = self.pos_drop(x)
        x = self.features(x)  # B _T _H _W C
        x = self.norm(x)
        x = x.permute(0, 4, 1, 2, 3)  # B, C, _T, _H, _W
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
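A minimal construction sketch for the class above; the tiny hyper-parameters are purely illustrative (they do not correspond to any published Swin3D configuration) and assume the module path used earlier:

import torch
from torchvision.models.video.swin_transformer import SwinTransformer3d

model = SwinTransformer3d(
    patch_size=[2, 4, 4],
    embed_dim=24,
    depths=[2, 2],
    num_heads=[2, 4],
    window_size=[2, 4, 4],
    num_classes=10,
).eval()

clip = torch.rand(1, 3, 8, 64, 64)   # B C T H W
with torch.inference_mode():
    print(model(clip).shape)         # expected: torch.Size([1, 10])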
def _swin_transformer3d(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer3d:
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = SwinTransformer3d(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        **kwargs,
    )

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model


_COMMON_META = {
    "categories": _KINETICS400_CATEGORIES,
    "min_size": (1, 1),
    "min_temporal_size": 1,
}
class Swin3D_T_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 28158070,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 77.715,
                    "acc@5": 93.519,
                }
            },
            "_ops": 43.882,
            "_file_size": 121.543,
        },
    )
    DEFAULT = KINETICS400_V1


class Swin3D_S_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_s-da41c237.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 49816678,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.521,
                    "acc@5": 94.158,
                }
            },
            "_ops": 82.841,
            "_file_size": 218.288,
        },
    )
    DEFAULT = KINETICS400_V1


class Swin3D_B_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.427,
                    "acc@5": 94.386,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    KINETICS400_IMAGENET22K_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 81.643,
                    "acc@5": 95.574,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    DEFAULT = KINETICS400_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1))
def swin3d_t(*, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_tiny architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_T_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_T_Weights
        :members:
    """
    weights = Swin3D_T_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )
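A usage sketch for the builders in this file; the random tensor below stands in for a real decoded video clip and the frame layout assumed is (T, C, H, W), which the bundled VideoClassification transform converts to (B, C, T, H, W):

import torch
from torchvision.models.video import swin3d_t, Swin3D_T_Weights

weights = Swin3D_T_Weights.KINETICS400_V1
model = swin3d_t(weights=weights).eval()
preprocess = weights.transforms()

clip = torch.rand(32, 3, 256, 256)        # 32 frames, values in [0, 1]
batch = preprocess(clip).unsqueeze(0)     # -> (1, 3, 32, 224, 224)
with torch.inference_mode():
    top = model(batch).softmax(-1).argmax().item()
print(weights.meta["categories"][top])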
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1))
def swin3d_s(*, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_small architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_S_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_S_Weights
        :members:
    """
    weights = Swin3D_S_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1))
def swin3d_b(*, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_base architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.Swin3D_B_Weights` below for
            more details, and possible values. By default, no pre-trained
            weights are used.
        progress (bool, optional): If True, displays a progress bar of the
            download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_B_Weights
        :members:
    """
    weights = Swin3D_B_Weights.verify(weights)

    return _swin_transformer3d(
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        weights=weights,
        progress=progress,
        **kwargs,
    )
torchvision/models/vision_transformer.py
View file @ cc26cd81

@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
     def forward(self, input: torch.Tensor):
         torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
         x = self.ln_1(input)
-        x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False)
+        x, _ = self.self_attention(x, x, x, need_weights=False)
         x = self.dropout(x)
         x = x + input

@@ -332,7 +332,7 @@ def _vision_transformer(
     )

     if weights:
-        model.load_state_dict(weights.get_state_dict(progress=progress))
+        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

     return model

@@ -363,6 +363,8 @@ class ViT_B_16_Weights(WeightsEnum):
                     "acc@5": 95.318,
                 }
             },
+            "_ops": 17.564,
+            "_file_size": 330.285,
             "_docs": """
                 These weights were trained from scratch by using a modified version of `DeIT
                 <https://arxiv.org/abs/2012.12877>`_'s training recipe.

@@ -387,6 +389,8 @@ class ViT_B_16_Weights(WeightsEnum):
                     "acc@5": 97.650,
                 }
             },
+            "_ops": 55.484,
+            "_file_size": 331.398,
             "_docs": """
                 These weights are learnt via transfer learning by end-to-end fine-tuning the original
                 `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.

@@ -412,6 +416,8 @@ class ViT_B_16_Weights(WeightsEnum):
                     "acc@5": 96.180,
                 }
             },
+            "_ops": 17.564,
+            "_file_size": 330.285,
             "_docs": """
                 These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                 weights and a linear classifier learnt on top of them trained on ImageNet-1K data.

@@ -436,6 +442,8 @@ class ViT_B_32_Weights(WeightsEnum):
                     "acc@5": 92.466,
                 }
             },
+            "_ops": 4.409,
+            "_file_size": 336.604,
             "_docs": """
                 These weights were trained from scratch by using a modified version of `DeIT
                 <https://arxiv.org/abs/2012.12877>`_'s training recipe.

@@ -460,6 +468,8 @@ class ViT_L_16_Weights(WeightsEnum):
                     "acc@5": 94.638,
                 }
             },
+            "_ops": 61.555,
+            "_file_size": 1161.023,
             "_docs": """
                 These weights were trained from scratch by using a modified version of TorchVision's
                 `new training recipe

@@ -485,6 +495,8 @@ class ViT_L_16_Weights(WeightsEnum):
                     "acc@5": 98.512,
                 }
             },
+            "_ops": 361.986,
+            "_file_size": 1164.258,
             "_docs": """
                 These weights are learnt via transfer learning by end-to-end fine-tuning the original
                 `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.

@@ -510,6 +522,8 @@ class ViT_L_16_Weights(WeightsEnum):
                     "acc@5": 97.422,
                 }
             },
+            "_ops": 61.555,
+            "_file_size": 1161.023,
             "_docs": """
                 These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                 weights and a linear classifier learnt on top of them trained on ImageNet-1K data.

@@ -534,6 +548,8 @@ class ViT_L_32_Weights(WeightsEnum):
                     "acc@5": 93.07,
                 }
             },
+            "_ops": 15.378,
+            "_file_size": 1169.449,
             "_docs": """
                 These weights were trained from scratch by using a modified version of `DeIT
                 <https://arxiv.org/abs/2012.12877>`_'s training recipe.

@@ -562,6 +578,8 @@ class ViT_H_14_Weights(WeightsEnum):
                     "acc@5": 98.694,
                 }
             },
+            "_ops": 1016.717,
+            "_file_size": 2416.643,
             "_docs": """
                 These weights are learnt via transfer learning by end-to-end fine-tuning the original
                 `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.

@@ -587,6 +605,8 @@ class ViT_H_14_Weights(WeightsEnum):
                     "acc@5": 97.730,
                 }
             },
+            "_ops": 167.295,
+            "_file_size": 2411.209,
             "_docs": """
                 These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
                 weights and a linear classifier learnt on top of them trained on ImageNet-1K data.

@@ -773,7 +793,7 @@ def interpolate_embeddings(
     interpolation_mode: str = "bicubic",
     reset_heads: bool = False,
 ) -> "OrderedDict[str, torch.Tensor]":
-    """This function helps interpolating positional embeddings during checkpoint loading,
+    """This function helps interpolate positional embeddings during checkpoint loading,
     especially when you want to apply a pre-trained model on images with different resolution.

     Args:

@@ -798,7 +818,7 @@ def interpolate_embeddings(
     # We do this by reshaping the positions embeddings to a 2d grid, performing
     # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
     if new_seq_length != seq_length:
-        # The class token embedding shouldn't be interpolated so we split it up.
+        # The class token embedding shouldn't be interpolated, so we split it up.
         seq_length -= 1
         new_seq_length -= 1
         pos_embedding_token = pos_embedding[:, :1, :]

@@ -842,17 +862,3 @@ def interpolate_embeddings(
             model_state = model_state_copy

     return model_state
-
-
-# The dictionary below is internal implementation detail and will be removed in v0.15
-from ._utils import _ModelURLs
-
-
-model_urls = _ModelURLs(
-    {
-        "vit_b_16": ViT_B_16_Weights.IMAGENET1K_V1.url,
-        "vit_b_32": ViT_B_32_Weights.IMAGENET1K_V1.url,
-        "vit_l_16": ViT_L_16_Weights.IMAGENET1K_V1.url,
-        "vit_l_32": ViT_L_32_Weights.IMAGENET1K_V1.url,
-    }
-)
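The newly added "_ops" (GFLOPs) and "_file_size" (MB) entries are plain metadata on the weight enums; a quick sketch of reading them:

from torchvision.models import vit_b_16, ViT_B_16_Weights

w = ViT_B_16_Weights.IMAGENET1K_V1
print(w.meta["_ops"], w.meta["_file_size"])   # 17.564 GFLOPs, 330.285 MB
model = vit_b_16(weights=w)   # checkpoint hashes are now verified via check_hash=True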
torchvision/ops/_box_convert.py
View file @ cc26cd81

@@ -50,7 +50,7 @@ def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
 def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
     """
     Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
-    (x, y) refers to top left of bouding box.
+    (x, y) refers to top left of bounding box.
     (w, h) refers to width and height of box.

     Args:
         boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
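The public wrapper around these helpers can be exercised like this (a quick sketch):

import torch
from torchvision.ops import box_convert

boxes_xywh = torch.tensor([[10.0, 20.0, 30.0, 40.0]])          # x, y, w, h
boxes_xyxy = box_convert(boxes_xywh, in_fmt="xywh", out_fmt="xyxy")
print(boxes_xyxy)                                              # tensor([[10., 20., 40., 60.]])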
torchvision/ops/_register_onnx_ops.py
View file @ cc26cd81

@@ -2,45 +2,65 @@ import sys
 import warnings

 import torch
+from torch.onnx import symbolic_opset11 as opset11
+from torch.onnx.symbolic_helper import parse_args

-_onnx_opset_version = 11
+_ONNX_OPSET_VERSION_11 = 11
+_ONNX_OPSET_VERSION_16 = 16
+BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11


-def _register_custom_op():
-    from torch.onnx.symbolic_helper import parse_args
-    from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
-    from torch.onnx.symbolic_opset9 import _cast_Long
-
-    @parse_args("v", "v", "f")
-    def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
-        boxes = unsqueeze(g, boxes, 0)
-        scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
+@parse_args("v", "v", "f")
+def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
+    boxes = opset11.unsqueeze(g, boxes, 0)
+    scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
     max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
     iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
-    nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
-    return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)
-
-    @parse_args("v", "v", "f", "i", "i", "i", "i")
-    def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
-        batch_indices = _cast_Long(
-            g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
-        )
+    # Cast boxes and scores to float32 in case they are float64 inputs
+    nms_out = g.op(
+        "NonMaxSuppression",
+        g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
+        g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
+        max_output_per_class,
+        iou_threshold,
+    )
-        rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
-        # TODO: Remove this warning after ONNX opset 16 is supported.
-        if aligned:
-            warnings.warn(
-                "ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
-                "The workaround is that the user need apply the patch "
-                "https://github.com/microsoft/onnxruntime/pull/8564 "
-                "and build ONNXRuntime from source."
-            )
+    return opset11.squeeze(
+        g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
+    )
-        # ONNX doesn't support negative sampling_ratio
+
+
+def _process_batch_indices_for_roi_align(g, rois):
+    indices = opset11.squeeze(
+        g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
+    )
+    return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
+
+
+def _process_rois_for_roi_align(g, rois):
+    return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
+
+
+def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
     if sampling_ratio < 0:
         warnings.warn(
-            "ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
+            "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
+            "The model will be exported with a sampling_ratio of 0."
         )
         sampling_ratio = 0
+    return sampling_ratio
+
+
+@parse_args("v", "v", "f", "i", "i", "i", "i")
+def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    batch_indices = _process_batch_indices_for_roi_align(g, rois)
+    rois = _process_rois_for_roi_align(g, rois)
+    if aligned:
+        warnings.warn(
+            "ROIAlign with aligned=True is only supported in opset >= 16. "
+            "Please export with opset 16 or higher, or use aligned=False."
+        )
+    sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
     return g.op(
         "RoiAlign",
         input,

@@ -52,15 +72,36 @@ def _register_custom_op():
         sampling_ratio_i=sampling_ratio,
     )

-    @parse_args("v", "v", "f", "i", "i")
-    def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
+@parse_args("v", "v", "f", "i", "i", "i", "i")
+def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
+    batch_indices = _process_batch_indices_for_roi_align(g, rois)
+    rois = _process_rois_for_roi_align(g, rois)
+    coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
+    sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
+    return g.op(
+        "RoiAlign",
+        input,
+        rois,
+        batch_indices,
+        coordinate_transformation_mode_s=coordinate_transformation_mode,
+        spatial_scale_f=spatial_scale,
+        output_height_i=pooled_height,
+        output_width_i=pooled_width,
+        sampling_ratio_i=sampling_ratio,
+    )
+
+
+@parse_args("v", "v", "f", "i", "i")
+def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
     roi_pool = g.op(
         "MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
     )
     return roi_pool, None

-    from torch.onnx import register_custom_op_symbolic
-
-    register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
-    register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
-    register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
+
+def _register_custom_op():
+    torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
+    torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)
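With the symbolics above registered when torchvision.ops is imported, custom ops such as NMS can be exported through torch.onnx; a hedged sketch (the wrapper module, file name, and shapes are illustrative only):

import torch
import torchvision

class NMSWrapper(torch.nn.Module):
    def forward(self, boxes, scores):
        return torchvision.ops.nms(boxes, scores, iou_threshold=0.5)

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])
scores = torch.tensor([0.9, 0.8])
# Exporting at opset 16 also picks up the opset-16 RoIAlign symbolic registered above.
torch.onnx.export(NMSWrapper(), (boxes, scores), "nms.onnx", opset_version=16)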
torchvision/ops/ciou_loss.py
View file @ cc26cd81

@@ -63,9 +63,16 @@ def complete_box_iou_loss(
     alpha = v / (1 - iou + v + eps)
     loss = diou_loss + alpha * v

-    if reduction == "mean":
+    # Check reduction option and return loss accordingly
+    if reduction == "none":
+        pass
+    elif reduction == "mean":
         loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
     elif reduction == "sum":
         loss = loss.sum()
+    else:
+        raise ValueError(
+            f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
+        )
     return loss
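The reduction options can be exercised like this (a quick sketch with two toy box pairs):

import torch
from torchvision.ops import complete_box_iou_loss

pred = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
target = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])

per_pair = complete_box_iou_loss(pred, target, reduction="none")   # shape (2,)
mean_loss = complete_box_iou_loss(pred, target, reduction="mean")  # scalar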
torchvision/ops/deform_conv.py
View file @ cc26cd81

@@ -68,7 +68,7 @@ def deform_conv2d(
     use_mask = mask is not None

     if mask is None:
-        mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype)
+        mask = torch.zeros((input.shape[0], 1), device=input.device, dtype=input.dtype)

     if bias is None:
         bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
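A small sketch of calling deform_conv2d with the optional mask left as None (shapes chosen so the arithmetic is easy to follow):

import torch
from torchvision.ops import deform_conv2d

inp = torch.rand(1, 3, 8, 8)
weight = torch.rand(5, 3, 3, 3)                 # out_channels=5, 3x3 kernel
offset = torch.zeros(1, 2 * 3 * 3, 6, 6)        # 2*kh*kw offset channels; 6x6 output (no padding, stride 1)
out = deform_conv2d(inp, offset, weight)        # mask=None -> plain (non-modulated) deformable conv
print(out.shape)                                # torch.Size([1, 5, 6, 6])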