Commit cc26cd81 authored by panning's avatar panning
Browse files

merge v0.16.0

parents f78f29f5 fbb4cc54
...@@ -108,7 +108,7 @@ def _shufflenetv2( ...@@ -108,7 +108,7 @@ def _shufflenetv2(
quantize_model(model, backend) quantize_model(model, backend)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -139,6 +139,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): ...@@ -139,6 +139,8 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
"acc@5": 79.780, "acc@5": 79.780,
} }
}, },
"_ops": 0.04,
"_file_size": 1.501,
}, },
) )
DEFAULT = IMAGENET1K_FBGEMM_V1 DEFAULT = IMAGENET1K_FBGEMM_V1
...@@ -146,7 +148,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum): ...@@ -146,7 +148,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights( IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth", url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-1e62bb32.pth",
transforms=partial(ImageClassification, crop_size=224), transforms=partial(ImageClassification, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
...@@ -158,6 +160,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum): ...@@ -158,6 +160,8 @@ class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
"acc@5": 87.582, "acc@5": 87.582,
} }
}, },
"_ops": 0.145,
"_file_size": 2.334,
}, },
) )
DEFAULT = IMAGENET1K_FBGEMM_V1 DEFAULT = IMAGENET1K_FBGEMM_V1
...@@ -178,6 +182,8 @@ class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum): ...@@ -178,6 +182,8 @@ class ShuffleNet_V2_X1_5_QuantizedWeights(WeightsEnum):
"acc@5": 90.700, "acc@5": 90.700,
} }
}, },
"_ops": 0.296,
"_file_size": 3.672,
}, },
) )
DEFAULT = IMAGENET1K_FBGEMM_V1 DEFAULT = IMAGENET1K_FBGEMM_V1
...@@ -198,6 +204,8 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum): ...@@ -198,6 +204,8 @@ class ShuffleNet_V2_X2_0_QuantizedWeights(WeightsEnum):
"acc@5": 92.488, "acc@5": 92.488,
} }
}, },
"_ops": 0.583,
"_file_size": 7.467,
}, },
) )
DEFAULT = IMAGENET1K_FBGEMM_V1 DEFAULT = IMAGENET1K_FBGEMM_V1
...@@ -417,16 +425,3 @@ def shufflenet_v2_x2_0( ...@@ -417,16 +425,3 @@ def shufflenet_v2_x2_0(
return _shufflenetv2( return _shufflenetv2(
[4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs [4, 8, 4], [24, 244, 488, 976, 2048], weights=weights, progress=progress, quantize=quantize, **kwargs
) )
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
from ..shufflenetv2 import model_urls # noqa: F401
quant_model_urls = _ModelURLs(
{
"shufflenetv2_x0.5_fbgemm": ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
"shufflenetv2_x1.0_fbgemm": ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1.url,
}
)
...@@ -212,7 +212,7 @@ class BlockParams: ...@@ -212,7 +212,7 @@ class BlockParams:
**kwargs: Any, **kwargs: Any,
) -> "BlockParams": ) -> "BlockParams":
""" """
Programatically compute all the per-block settings, Programmatically compute all the per-block settings,
given the RegNet parameters. given the RegNet parameters.
The first step is to compute the quantized linear block parameters, The first step is to compute the quantized linear block parameters,
...@@ -397,7 +397,7 @@ def _regnet( ...@@ -397,7 +397,7 @@ def _regnet(
model = RegNet(block_params, norm_layer=norm_layer, **kwargs) model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -428,6 +428,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ...@@ -428,6 +428,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@5": 91.716, "acc@5": 91.716,
} }
}, },
"_ops": 0.402,
"_file_size": 16.806,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -444,6 +446,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum): ...@@ -444,6 +446,8 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
"acc@5": 92.742, "acc@5": 92.742,
} }
}, },
"_ops": 0.402,
"_file_size": 16.806,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -468,6 +472,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ...@@ -468,6 +472,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@5": 93.136, "acc@5": 93.136,
} }
}, },
"_ops": 0.834,
"_file_size": 24.774,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -484,6 +490,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum): ...@@ -484,6 +490,8 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
"acc@5": 94.502, "acc@5": 94.502,
} }
}, },
"_ops": 0.834,
"_file_size": 24.774,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -508,6 +516,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ...@@ -508,6 +516,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@5": 93.966, "acc@5": 93.966,
} }
}, },
"_ops": 1.612,
"_file_size": 43.152,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -524,6 +534,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum): ...@@ -524,6 +534,8 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
"acc@5": 95.444, "acc@5": 95.444,
} }
}, },
"_ops": 1.612,
"_file_size": 43.152,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -548,6 +560,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ...@@ -548,6 +560,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@5": 94.576, "acc@5": 94.576,
} }
}, },
"_ops": 3.176,
"_file_size": 74.567,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -564,6 +578,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum): ...@@ -564,6 +578,8 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
"acc@5": 95.972, "acc@5": 95.972,
} }
}, },
"_ops": 3.176,
"_file_size": 74.567,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -588,6 +604,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ...@@ -588,6 +604,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@5": 95.048, "acc@5": 95.048,
} }
}, },
"_ops": 8.473,
"_file_size": 150.701,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -604,6 +622,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum): ...@@ -604,6 +622,8 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
"acc@5": 96.330, "acc@5": 96.330,
} }
}, },
"_ops": 8.473,
"_file_size": 150.701,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -628,6 +648,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -628,6 +648,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 95.240, "acc@5": 95.240,
} }
}, },
"_ops": 15.912,
"_file_size": 319.49,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -644,6 +666,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -644,6 +666,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 96.328, "acc@5": 96.328,
} }
}, },
"_ops": 15.912,
"_file_size": 319.49,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -665,6 +689,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -665,6 +689,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 98.054, "acc@5": 98.054,
} }
}, },
"_ops": 46.735,
"_file_size": 319.49,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -686,6 +712,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum): ...@@ -686,6 +712,8 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
"acc@5": 97.244, "acc@5": 97.244,
} }
}, },
"_ops": 15.912,
"_file_size": 319.49,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -709,6 +737,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -709,6 +737,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 95.340, "acc@5": 95.340,
} }
}, },
"_ops": 32.28,
"_file_size": 554.076,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -725,6 +755,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -725,6 +755,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 96.498, "acc@5": 96.498,
} }
}, },
"_ops": 32.28,
"_file_size": 554.076,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -746,6 +778,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -746,6 +778,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 98.362, "acc@5": 98.362,
} }
}, },
"_ops": 94.826,
"_file_size": 554.076,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -767,6 +801,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum): ...@@ -767,6 +801,8 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
"acc@5": 97.480, "acc@5": 97.480,
} }
}, },
"_ops": 32.28,
"_file_size": 554.076,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -791,6 +827,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum): ...@@ -791,6 +827,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5": 98.682, "acc@5": 98.682,
} }
}, },
"_ops": 374.57,
"_file_size": 2461.564,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -812,6 +850,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum): ...@@ -812,6 +850,8 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
"acc@5": 97.844, "acc@5": 97.844,
} }
}, },
"_ops": 127.518,
"_file_size": 2461.564,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -835,6 +875,8 @@ class RegNet_X_400MF_Weights(WeightsEnum): ...@@ -835,6 +875,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@5": 90.950, "acc@5": 90.950,
} }
}, },
"_ops": 0.414,
"_file_size": 21.258,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -851,6 +893,8 @@ class RegNet_X_400MF_Weights(WeightsEnum): ...@@ -851,6 +893,8 @@ class RegNet_X_400MF_Weights(WeightsEnum):
"acc@5": 92.322, "acc@5": 92.322,
} }
}, },
"_ops": 0.414,
"_file_size": 21.257,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -875,6 +919,8 @@ class RegNet_X_800MF_Weights(WeightsEnum): ...@@ -875,6 +919,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@5": 92.348, "acc@5": 92.348,
} }
}, },
"_ops": 0.8,
"_file_size": 27.945,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -891,6 +937,8 @@ class RegNet_X_800MF_Weights(WeightsEnum): ...@@ -891,6 +937,8 @@ class RegNet_X_800MF_Weights(WeightsEnum):
"acc@5": 93.826, "acc@5": 93.826,
} }
}, },
"_ops": 0.8,
"_file_size": 27.945,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -915,6 +963,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ...@@ -915,6 +963,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@5": 93.440, "acc@5": 93.440,
} }
}, },
"_ops": 1.603,
"_file_size": 35.339,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -931,6 +981,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum): ...@@ -931,6 +981,8 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
"acc@5": 94.922, "acc@5": 94.922,
} }
}, },
"_ops": 1.603,
"_file_size": 35.339,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -955,6 +1007,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ...@@ -955,6 +1007,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@5": 93.992, "acc@5": 93.992,
} }
}, },
"_ops": 3.177,
"_file_size": 58.756,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -971,6 +1025,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum): ...@@ -971,6 +1025,8 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
"acc@5": 95.430, "acc@5": 95.430,
} }
}, },
"_ops": 3.177,
"_file_size": 58.756,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -995,6 +1051,8 @@ class RegNet_X_8GF_Weights(WeightsEnum): ...@@ -995,6 +1051,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@5": 94.686, "acc@5": 94.686,
} }
}, },
"_ops": 7.995,
"_file_size": 151.456,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -1011,6 +1069,8 @@ class RegNet_X_8GF_Weights(WeightsEnum): ...@@ -1011,6 +1069,8 @@ class RegNet_X_8GF_Weights(WeightsEnum):
"acc@5": 95.678, "acc@5": 95.678,
} }
}, },
"_ops": 7.995,
"_file_size": 151.456,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -1035,6 +1095,8 @@ class RegNet_X_16GF_Weights(WeightsEnum): ...@@ -1035,6 +1095,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@5": 94.944, "acc@5": 94.944,
} }
}, },
"_ops": 15.941,
"_file_size": 207.627,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -1051,6 +1113,8 @@ class RegNet_X_16GF_Weights(WeightsEnum): ...@@ -1051,6 +1113,8 @@ class RegNet_X_16GF_Weights(WeightsEnum):
"acc@5": 96.196, "acc@5": 96.196,
} }
}, },
"_ops": 15.941,
"_file_size": 207.627,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -1075,6 +1139,8 @@ class RegNet_X_32GF_Weights(WeightsEnum): ...@@ -1075,6 +1139,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@5": 95.248, "acc@5": 95.248,
} }
}, },
"_ops": 31.736,
"_file_size": 412.039,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -1091,6 +1157,8 @@ class RegNet_X_32GF_Weights(WeightsEnum): ...@@ -1091,6 +1157,8 @@ class RegNet_X_32GF_Weights(WeightsEnum):
"acc@5": 96.288, "acc@5": 96.288,
} }
}, },
"_ops": 31.736,
"_file_size": 412.039,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using a modified version of TorchVision's These weights improve upon the results of the original paper by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -1501,27 +1569,3 @@ def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress: ...@@ -1501,27 +1569,3 @@ def regnet_x_32gf(*, weights: Optional[RegNet_X_32GF_Weights] = None, progress:
params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs) params = BlockParams.from_init_params(depth=23, w_0=320, w_a=69.86, w_m=2.0, group_width=168, **kwargs)
return _regnet(params, weights, progress, **kwargs) return _regnet(params, weights, progress, **kwargs)
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"regnet_y_400mf": RegNet_Y_400MF_Weights.IMAGENET1K_V1.url,
"regnet_y_800mf": RegNet_Y_800MF_Weights.IMAGENET1K_V1.url,
"regnet_y_1_6gf": RegNet_Y_1_6GF_Weights.IMAGENET1K_V1.url,
"regnet_y_3_2gf": RegNet_Y_3_2GF_Weights.IMAGENET1K_V1.url,
"regnet_y_8gf": RegNet_Y_8GF_Weights.IMAGENET1K_V1.url,
"regnet_y_16gf": RegNet_Y_16GF_Weights.IMAGENET1K_V1.url,
"regnet_y_32gf": RegNet_Y_32GF_Weights.IMAGENET1K_V1.url,
"regnet_x_400mf": RegNet_X_400MF_Weights.IMAGENET1K_V1.url,
"regnet_x_800mf": RegNet_X_800MF_Weights.IMAGENET1K_V1.url,
"regnet_x_1_6gf": RegNet_X_1_6GF_Weights.IMAGENET1K_V1.url,
"regnet_x_3_2gf": RegNet_X_3_2GF_Weights.IMAGENET1K_V1.url,
"regnet_x_8gf": RegNet_X_8GF_Weights.IMAGENET1K_V1.url,
"regnet_x_16gf": RegNet_X_16GF_Weights.IMAGENET1K_V1.url,
"regnet_x_32gf": RegNet_X_32GF_Weights.IMAGENET1K_V1.url,
}
)
...@@ -108,7 +108,7 @@ class BasicBlock(nn.Module): ...@@ -108,7 +108,7 @@ class BasicBlock(nn.Module):
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1) # while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to # This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
...@@ -298,7 +298,7 @@ def _resnet( ...@@ -298,7 +298,7 @@ def _resnet(
model = ResNet(block, layers, **kwargs) model = ResNet(block, layers, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -323,6 +323,8 @@ class ResNet18_Weights(WeightsEnum): ...@@ -323,6 +323,8 @@ class ResNet18_Weights(WeightsEnum):
"acc@5": 89.078, "acc@5": 89.078,
} }
}, },
"_ops": 1.814,
"_file_size": 44.661,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -343,6 +345,8 @@ class ResNet34_Weights(WeightsEnum): ...@@ -343,6 +345,8 @@ class ResNet34_Weights(WeightsEnum):
"acc@5": 91.420, "acc@5": 91.420,
} }
}, },
"_ops": 3.664,
"_file_size": 83.275,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -363,6 +367,8 @@ class ResNet50_Weights(WeightsEnum): ...@@ -363,6 +367,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@5": 92.862, "acc@5": 92.862,
} }
}, },
"_ops": 4.089,
"_file_size": 97.781,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -379,6 +385,8 @@ class ResNet50_Weights(WeightsEnum): ...@@ -379,6 +385,8 @@ class ResNet50_Weights(WeightsEnum):
"acc@5": 95.434, "acc@5": 95.434,
} }
}, },
"_ops": 4.089,
"_file_size": 97.79,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -402,6 +410,8 @@ class ResNet101_Weights(WeightsEnum): ...@@ -402,6 +410,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@5": 93.546, "acc@5": 93.546,
} }
}, },
"_ops": 7.801,
"_file_size": 170.511,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -418,6 +428,8 @@ class ResNet101_Weights(WeightsEnum): ...@@ -418,6 +428,8 @@ class ResNet101_Weights(WeightsEnum):
"acc@5": 95.780, "acc@5": 95.780,
} }
}, },
"_ops": 7.801,
"_file_size": 170.53,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -441,6 +453,8 @@ class ResNet152_Weights(WeightsEnum): ...@@ -441,6 +453,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@5": 94.046, "acc@5": 94.046,
} }
}, },
"_ops": 11.514,
"_file_size": 230.434,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -457,6 +471,8 @@ class ResNet152_Weights(WeightsEnum): ...@@ -457,6 +471,8 @@ class ResNet152_Weights(WeightsEnum):
"acc@5": 96.002, "acc@5": 96.002,
} }
}, },
"_ops": 11.514,
"_file_size": 230.474,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -480,6 +496,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ...@@ -480,6 +496,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@5": 93.698, "acc@5": 93.698,
} }
}, },
"_ops": 4.23,
"_file_size": 95.789,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -496,6 +514,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum): ...@@ -496,6 +514,8 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
"acc@5": 95.340, "acc@5": 95.340,
} }
}, },
"_ops": 4.23,
"_file_size": 95.833,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -519,6 +539,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ...@@ -519,6 +539,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@5": 94.526, "acc@5": 94.526,
} }
}, },
"_ops": 16.414,
"_file_size": 339.586,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -535,6 +557,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum): ...@@ -535,6 +557,8 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
"acc@5": 96.228, "acc@5": 96.228,
} }
}, },
"_ops": 16.414,
"_file_size": 339.673,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -558,6 +582,8 @@ class ResNeXt101_64X4D_Weights(WeightsEnum): ...@@ -558,6 +582,8 @@ class ResNeXt101_64X4D_Weights(WeightsEnum):
"acc@5": 96.454, "acc@5": 96.454,
} }
}, },
"_ops": 15.46,
"_file_size": 319.318,
"_docs": """ "_docs": """
These weights were trained from scratch by using TorchVision's `new training recipe These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -581,6 +607,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ...@@ -581,6 +607,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@5": 94.086, "acc@5": 94.086,
} }
}, },
"_ops": 11.398,
"_file_size": 131.82,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -597,6 +625,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum): ...@@ -597,6 +625,8 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
"acc@5": 95.758, "acc@5": 95.758,
} }
}, },
"_ops": 11.398,
"_file_size": 263.124,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -620,6 +650,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -620,6 +650,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@5": 94.284, "acc@5": 94.284,
} }
}, },
"_ops": 22.753,
"_file_size": 242.896,
"_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""",
}, },
) )
...@@ -636,6 +668,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -636,6 +668,8 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
"acc@5": 96.020, "acc@5": 96.020,
} }
}, },
"_ops": 22.753,
"_file_size": 484.747,
"_docs": """ "_docs": """
These weights improve upon the results of the original paper by using TorchVision's `new training recipe These weights improve upon the results of the original paper by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -648,7 +682,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum): ...@@ -648,7 +682,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1))
def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__. """ResNet-18 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
Args: Args:
weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The
...@@ -674,7 +708,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru ...@@ -674,7 +708,7 @@ def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = Tru
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1))
def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__. """ResNet-34 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
Args: Args:
weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The
...@@ -700,7 +734,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru ...@@ -700,7 +734,7 @@ def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = Tru
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1))
def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__. """ResNet-50 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
.. note:: .. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3 The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...@@ -732,7 +766,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru ...@@ -732,7 +766,7 @@ def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = Tru
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1))
def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__. """ResNet-101 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
.. note:: .. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3 The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...@@ -764,7 +798,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T ...@@ -764,7 +798,7 @@ def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = T
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) @handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1))
def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet:
"""ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`__. """ResNet-152 from `Deep Residual Learning for Image Recognition <https://arxiv.org/abs/1512.03385>`__.
.. note:: .. note::
The bottleneck of TorchVision places the stride for downsampling to the second 3x3 The bottleneck of TorchVision places the stride for downsampling to the second 3x3
...@@ -949,22 +983,3 @@ def wide_resnet101_2( ...@@ -949,22 +983,3 @@ def wide_resnet101_2(
_ovewrite_named_param(kwargs, "width_per_group", 64 * 2) _ovewrite_named_param(kwargs, "width_per_group", 64 * 2)
return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs)
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"resnet18": ResNet18_Weights.IMAGENET1K_V1.url,
"resnet34": ResNet34_Weights.IMAGENET1K_V1.url,
"resnet50": ResNet50_Weights.IMAGENET1K_V1.url,
"resnet101": ResNet101_Weights.IMAGENET1K_V1.url,
"resnet152": ResNet152_Weights.IMAGENET1K_V1.url,
"resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url,
"resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url,
"wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url,
"wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url,
}
)
...@@ -152,6 +152,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum): ...@@ -152,6 +152,8 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
"pixel_acc": 92.4, "pixel_acc": 92.4,
} }
}, },
"_ops": 178.722,
"_file_size": 160.515,
}, },
) )
DEFAULT = COCO_WITH_VOC_LABELS_V1 DEFAULT = COCO_WITH_VOC_LABELS_V1
...@@ -171,6 +173,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum): ...@@ -171,6 +173,8 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
"pixel_acc": 92.4, "pixel_acc": 92.4,
} }
}, },
"_ops": 258.743,
"_file_size": 233.217,
}, },
) )
DEFAULT = COCO_WITH_VOC_LABELS_V1 DEFAULT = COCO_WITH_VOC_LABELS_V1
...@@ -190,6 +194,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -190,6 +194,8 @@ class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
"pixel_acc": 91.2, "pixel_acc": 91.2,
} }
}, },
"_ops": 10.452,
"_file_size": 42.301,
}, },
) )
DEFAULT = COCO_WITH_VOC_LABELS_V1 DEFAULT = COCO_WITH_VOC_LABELS_V1
...@@ -269,7 +275,7 @@ def deeplabv3_resnet50( ...@@ -269,7 +275,7 @@ def deeplabv3_resnet50(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -325,7 +331,7 @@ def deeplabv3_resnet101( ...@@ -325,7 +331,7 @@ def deeplabv3_resnet101(
model = _deeplabv3_resnet(backbone, num_classes, aux_loss) model = _deeplabv3_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -379,19 +385,6 @@ def deeplabv3_mobilenet_v3_large( ...@@ -379,19 +385,6 @@ def deeplabv3_mobilenet_v3_large(
model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss) model = _deeplabv3_mobilenetv3(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
"deeplabv3_resnet50_coco": DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
"deeplabv3_resnet101_coco": DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
"deeplabv3_mobilenet_v3_large_coco": DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
}
)
...@@ -71,6 +71,8 @@ class FCN_ResNet50_Weights(WeightsEnum): ...@@ -71,6 +71,8 @@ class FCN_ResNet50_Weights(WeightsEnum):
"pixel_acc": 91.4, "pixel_acc": 91.4,
} }
}, },
"_ops": 152.717,
"_file_size": 135.009,
}, },
) )
DEFAULT = COCO_WITH_VOC_LABELS_V1 DEFAULT = COCO_WITH_VOC_LABELS_V1
...@@ -90,6 +92,8 @@ class FCN_ResNet101_Weights(WeightsEnum): ...@@ -90,6 +92,8 @@ class FCN_ResNet101_Weights(WeightsEnum):
"pixel_acc": 91.9, "pixel_acc": 91.9,
} }
}, },
"_ops": 232.738,
"_file_size": 207.711,
}, },
) )
DEFAULT = COCO_WITH_VOC_LABELS_V1 DEFAULT = COCO_WITH_VOC_LABELS_V1
...@@ -164,7 +168,7 @@ def fcn_resnet50( ...@@ -164,7 +168,7 @@ def fcn_resnet50(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -223,18 +227,6 @@ def fcn_resnet101( ...@@ -223,18 +227,6 @@ def fcn_resnet101(
model = _fcn_resnet(backbone, num_classes, aux_loss) model = _fcn_resnet(backbone, num_classes, aux_loss)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
"fcn_resnet50_coco": FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1.url,
"fcn_resnet101_coco": FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1.url,
}
)
...@@ -108,6 +108,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum): ...@@ -108,6 +108,8 @@ class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
"pixel_acc": 91.2, "pixel_acc": 91.2,
} }
}, },
"_ops": 2.086,
"_file_size": 12.49,
"_docs": """ "_docs": """
These weights were trained on a subset of COCO, using only the 20 categories that are present in the These weights were trained on a subset of COCO, using only the 20 categories that are present in the
Pascal VOC dataset. Pascal VOC dataset.
...@@ -171,17 +173,6 @@ def lraspp_mobilenet_v3_large( ...@@ -171,17 +173,6 @@ def lraspp_mobilenet_v3_large(
model = _lraspp_mobilenetv3(backbone, num_classes) model = _lraspp_mobilenetv3(backbone, num_classes)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
# The dictionary below is internal implementation detail and will be removed in v0.15
from .._utils import _ModelURLs
model_urls = _ModelURLs(
{
"lraspp_mobilenet_v3_large_coco": LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1.url,
}
)
...@@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor: ...@@ -35,7 +35,7 @@ def channel_shuffle(x: Tensor, groups: int) -> Tensor:
x = torch.transpose(x, 1, 2).contiguous() x = torch.transpose(x, 1, 2).contiguous()
# flatten # flatten
x = x.view(batchsize, -1, height, width) x = x.view(batchsize, num_channels, height, width)
return x return x
...@@ -178,7 +178,7 @@ def _shufflenetv2( ...@@ -178,7 +178,7 @@ def _shufflenetv2(
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -204,6 +204,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum): ...@@ -204,6 +204,8 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
"acc@5": 81.746, "acc@5": 81.746,
} }
}, },
"_ops": 0.04,
"_file_size": 5.282,
"_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""", "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""",
}, },
) )
...@@ -224,6 +226,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum): ...@@ -224,6 +226,8 @@ class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
"acc@5": 88.316, "acc@5": 88.316,
} }
}, },
"_ops": 0.145,
"_file_size": 8.791,
"_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""", "_docs": """These weights were trained from scratch to reproduce closely the results of the paper.""",
}, },
) )
...@@ -244,6 +248,8 @@ class ShuffleNet_V2_X1_5_Weights(WeightsEnum): ...@@ -244,6 +248,8 @@ class ShuffleNet_V2_X1_5_Weights(WeightsEnum):
"acc@5": 91.086, "acc@5": 91.086,
} }
}, },
"_ops": 0.296,
"_file_size": 13.557,
"_docs": """ "_docs": """
These weights were trained from scratch by using TorchVision's `new training recipe These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -267,6 +273,8 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum): ...@@ -267,6 +273,8 @@ class ShuffleNet_V2_X2_0_Weights(WeightsEnum):
"acc@5": 93.006, "acc@5": 93.006,
} }
}, },
"_ops": 0.583,
"_file_size": 28.433,
"_docs": """ "_docs": """
These weights were trained from scratch by using TorchVision's `new training recipe These weights were trained from scratch by using TorchVision's `new training recipe
<https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_. <https://pytorch.org/blog/how-to-train-state-of-the-art-models-using-torchvision-latest-primitives/>`_.
...@@ -398,17 +406,3 @@ def shufflenet_v2_x2_0( ...@@ -398,17 +406,3 @@ def shufflenet_v2_x2_0(
weights = ShuffleNet_V2_X2_0_Weights.verify(weights) weights = ShuffleNet_V2_X2_0_Weights.verify(weights)
return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) return _shufflenetv2(weights, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"shufflenetv2_x0.5": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1.url,
"shufflenetv2_x1.0": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1.url,
"shufflenetv2_x1.5": None,
"shufflenetv2_x2.0": None,
}
)
...@@ -109,7 +109,7 @@ def _squeezenet( ...@@ -109,7 +109,7 @@ def _squeezenet(
model = SqueezeNet(version, **kwargs) model = SqueezeNet(version, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -135,6 +135,8 @@ class SqueezeNet1_0_Weights(WeightsEnum): ...@@ -135,6 +135,8 @@ class SqueezeNet1_0_Weights(WeightsEnum):
"acc@5": 80.420, "acc@5": 80.420,
} }
}, },
"_ops": 0.819,
"_file_size": 4.778,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -154,6 +156,8 @@ class SqueezeNet1_1_Weights(WeightsEnum): ...@@ -154,6 +156,8 @@ class SqueezeNet1_1_Weights(WeightsEnum):
"acc@5": 80.624, "acc@5": 80.624,
} }
}, },
"_ops": 0.349,
"_file_size": 4.729,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -217,15 +221,3 @@ def squeezenet1_1( ...@@ -217,15 +221,3 @@ def squeezenet1_1(
""" """
weights = SqueezeNet1_1_Weights.verify(weights) weights = SqueezeNet1_1_Weights.verify(weights)
return _squeezenet("1_1", weights, progress, **kwargs) return _squeezenet("1_1", weights, progress, **kwargs)
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"squeezenet1_0": SqueezeNet1_0_Weights.IMAGENET1K_V1.url,
"squeezenet1_1": SqueezeNet1_1_Weights.IMAGENET1K_V1.url,
}
)
...@@ -126,7 +126,8 @@ def shifted_window_attention( ...@@ -126,7 +126,8 @@ def shifted_window_attention(
qkv_bias: Optional[Tensor] = None, qkv_bias: Optional[Tensor] = None,
proj_bias: Optional[Tensor] = None, proj_bias: Optional[Tensor] = None,
logit_scale: Optional[torch.Tensor] = None, logit_scale: Optional[torch.Tensor] = None,
): training: bool = True,
) -> Tensor:
""" """
Window based multi-head self attention (W-MSA) module with relative position bias. Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window. It supports both of shifted and non-shifted window.
...@@ -143,6 +144,7 @@ def shifted_window_attention( ...@@ -143,6 +144,7 @@ def shifted_window_attention(
qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None. logit_scale (Tensor[out_dim], optional): Logit scale of cosine attention for Swin Transformer V2. Default: None.
training (bool, optional): Training flag used by the dropout parameters. Default: True.
Returns: Returns:
Tensor[N, H, W, C]: The output tensor after shifted window attention. Tensor[N, H, W, C]: The output tensor after shifted window attention.
""" """
...@@ -207,11 +209,11 @@ def shifted_window_attention( ...@@ -207,11 +209,11 @@ def shifted_window_attention(
attn = attn.view(-1, num_heads, x.size(1), x.size(1)) attn = attn.view(-1, num_heads, x.size(1), x.size(1))
attn = F.softmax(attn, dim=-1) attn = F.softmax(attn, dim=-1)
attn = F.dropout(attn, p=attention_dropout) attn = F.dropout(attn, p=attention_dropout, training=training)
x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
x = F.linear(x, proj_weight, proj_bias) x = F.linear(x, proj_weight, proj_bias)
x = F.dropout(x, p=dropout) x = F.dropout(x, p=dropout, training=training)
# reverse windows # reverse windows
x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
...@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module): ...@@ -286,7 +288,7 @@ class ShiftedWindowAttention(nn.Module):
self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type]
) )
def forward(self, x: Tensor): def forward(self, x: Tensor) -> Tensor:
""" """
Args: Args:
x (Tensor): Tensor with layout of [B, H, W, C] x (Tensor): Tensor with layout of [B, H, W, C]
...@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module): ...@@ -306,6 +308,7 @@ class ShiftedWindowAttention(nn.Module):
dropout=self.dropout, dropout=self.dropout,
qkv_bias=self.qkv.bias, qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias, proj_bias=self.proj.bias,
training=self.training,
) )
...@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention): ...@@ -391,6 +394,7 @@ class ShiftedWindowAttentionV2(ShiftedWindowAttention):
qkv_bias=self.qkv.bias, qkv_bias=self.qkv.bias,
proj_bias=self.proj.bias, proj_bias=self.proj.bias,
logit_scale=self.logit_scale, logit_scale=self.logit_scale,
training=self.training,
) )
...@@ -494,6 +498,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock): ...@@ -494,6 +498,8 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
) )
def forward(self, x: Tensor): def forward(self, x: Tensor):
# Here is the difference, we apply norm after the attention in V2.
# In V1 we applied norm before the attention.
x = x + self.stochastic_depth(self.norm1(self.attn(x))) x = x + self.stochastic_depth(self.norm1(self.attn(x)))
x = x + self.stochastic_depth(self.norm2(self.mlp(x))) x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
return x return x
...@@ -502,7 +508,7 @@ class SwinTransformerBlockV2(SwinTransformerBlock): ...@@ -502,7 +508,7 @@ class SwinTransformerBlockV2(SwinTransformerBlock):
class SwinTransformer(nn.Module): class SwinTransformer(nn.Module):
""" """
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper. Shifted Windows" <https://arxiv.org/abs/2103.14030>`_ paper.
Args: Args:
patch_size (List[int]): Patch size. patch_size (List[int]): Patch size.
embed_dim (int): Patch embedding dimension. embed_dim (int): Patch embedding dimension.
...@@ -587,7 +593,7 @@ class SwinTransformer(nn.Module): ...@@ -587,7 +593,7 @@ class SwinTransformer(nn.Module):
num_features = embed_dim * 2 ** (len(depths) - 1) num_features = embed_dim * 2 ** (len(depths) - 1)
self.norm = norm_layer(num_features) self.norm = norm_layer(num_features)
self.permute = Permute([0, 3, 1, 2]) self.permute = Permute([0, 3, 1, 2]) # B H W C -> B C H W
self.avgpool = nn.AdaptiveAvgPool2d(1) self.avgpool = nn.AdaptiveAvgPool2d(1)
self.flatten = nn.Flatten(1) self.flatten = nn.Flatten(1)
self.head = nn.Linear(num_features, num_classes) self.head = nn.Linear(num_features, num_classes)
...@@ -633,7 +639,7 @@ def _swin_transformer( ...@@ -633,7 +639,7 @@ def _swin_transformer(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -660,6 +666,8 @@ class Swin_T_Weights(WeightsEnum): ...@@ -660,6 +666,8 @@ class Swin_T_Weights(WeightsEnum):
"acc@5": 95.776, "acc@5": 95.776,
} }
}, },
"_ops": 4.491,
"_file_size": 108.19,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -683,6 +691,8 @@ class Swin_S_Weights(WeightsEnum): ...@@ -683,6 +691,8 @@ class Swin_S_Weights(WeightsEnum):
"acc@5": 96.360, "acc@5": 96.360,
} }
}, },
"_ops": 8.741,
"_file_size": 189.786,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -706,6 +716,8 @@ class Swin_B_Weights(WeightsEnum): ...@@ -706,6 +716,8 @@ class Swin_B_Weights(WeightsEnum):
"acc@5": 96.640, "acc@5": 96.640,
} }
}, },
"_ops": 15.431,
"_file_size": 335.364,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -729,6 +741,8 @@ class Swin_V2_T_Weights(WeightsEnum): ...@@ -729,6 +741,8 @@ class Swin_V2_T_Weights(WeightsEnum):
"acc@5": 96.132, "acc@5": 96.132,
} }
}, },
"_ops": 5.94,
"_file_size": 108.626,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -752,6 +766,8 @@ class Swin_V2_S_Weights(WeightsEnum): ...@@ -752,6 +766,8 @@ class Swin_V2_S_Weights(WeightsEnum):
"acc@5": 96.816, "acc@5": 96.816,
} }
}, },
"_ops": 11.546,
"_file_size": 190.675,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -775,6 +791,8 @@ class Swin_V2_B_Weights(WeightsEnum): ...@@ -775,6 +791,8 @@ class Swin_V2_B_Weights(WeightsEnum):
"acc@5": 96.864, "acc@5": 96.864,
} }
}, },
"_ops": 20.325,
"_file_size": 336.372,
"_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""",
}, },
) )
...@@ -786,7 +804,7 @@ class Swin_V2_B_Weights(WeightsEnum): ...@@ -786,7 +804,7 @@ class Swin_V2_B_Weights(WeightsEnum):
def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_tiny architecture from Constructs a swin_tiny architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The
...@@ -824,7 +842,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, * ...@@ -824,7 +842,7 @@ def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, *
def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_small architecture from Constructs a swin_small architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The
...@@ -862,7 +880,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, * ...@@ -862,7 +880,7 @@ def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, *
def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_base architecture from Constructs a swin_base architecture from
`Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/pdf/2103.14030>`_. `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows <https://arxiv.org/abs/2103.14030>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The
...@@ -900,7 +918,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, * ...@@ -900,7 +918,7 @@ def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, *
def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_v2_tiny architecture from Constructs a swin_v2_tiny architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The weights (:class:`~torchvision.models.Swin_V2_T_Weights`, optional): The
...@@ -940,7 +958,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T ...@@ -940,7 +958,7 @@ def swin_v2_t(*, weights: Optional[Swin_V2_T_Weights] = None, progress: bool = T
def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_v2_small architecture from Constructs a swin_v2_small architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The weights (:class:`~torchvision.models.Swin_V2_S_Weights`, optional): The
...@@ -980,7 +998,7 @@ def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = T ...@@ -980,7 +998,7 @@ def swin_v2_s(*, weights: Optional[Swin_V2_S_Weights] = None, progress: bool = T
def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: def swin_v2_b(*, weights: Optional[Swin_V2_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer:
""" """
Constructs a swin_v2_base architecture from Constructs a swin_v2_base architecture from
`Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/pdf/2111.09883>`_. `Swin Transformer V2: Scaling Up Capacity and Resolution <https://arxiv.org/abs/2111.09883>`_.
Args: Args:
weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The weights (:class:`~torchvision.models.Swin_V2_B_Weights`, optional): The
......
...@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b ...@@ -102,7 +102,7 @@ def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: b
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -127,6 +127,8 @@ class VGG11_Weights(WeightsEnum): ...@@ -127,6 +127,8 @@ class VGG11_Weights(WeightsEnum):
"acc@5": 88.628, "acc@5": 88.628,
} }
}, },
"_ops": 7.609,
"_file_size": 506.84,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum): ...@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum):
"acc@5": 89.810, "acc@5": 89.810,
} }
}, },
"_ops": 7.609,
"_file_size": 506.881,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -163,6 +167,8 @@ class VGG13_Weights(WeightsEnum): ...@@ -163,6 +167,8 @@ class VGG13_Weights(WeightsEnum):
"acc@5": 89.246, "acc@5": 89.246,
} }
}, },
"_ops": 11.308,
"_file_size": 507.545,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -181,6 +187,8 @@ class VGG13_BN_Weights(WeightsEnum): ...@@ -181,6 +187,8 @@ class VGG13_BN_Weights(WeightsEnum):
"acc@5": 90.374, "acc@5": 90.374,
} }
}, },
"_ops": 11.308,
"_file_size": 507.59,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -199,6 +207,8 @@ class VGG16_Weights(WeightsEnum): ...@@ -199,6 +207,8 @@ class VGG16_Weights(WeightsEnum):
"acc@5": 90.382, "acc@5": 90.382,
} }
}, },
"_ops": 15.47,
"_file_size": 527.796,
}, },
) )
IMAGENET1K_FEATURES = Weights( IMAGENET1K_FEATURES = Weights(
...@@ -221,6 +231,8 @@ class VGG16_Weights(WeightsEnum): ...@@ -221,6 +231,8 @@ class VGG16_Weights(WeightsEnum):
"acc@5": float("nan"), "acc@5": float("nan"),
} }
}, },
"_ops": 15.47,
"_file_size": 527.802,
"_docs": """ "_docs": """
These weights can't be used for classification because they are missing values in the `classifier` These weights can't be used for classification because they are missing values in the `classifier`
module. Only the `features` module has valid values and can be used for feature extraction. The weights module. Only the `features` module has valid values and can be used for feature extraction. The weights
...@@ -244,6 +256,8 @@ class VGG16_BN_Weights(WeightsEnum): ...@@ -244,6 +256,8 @@ class VGG16_BN_Weights(WeightsEnum):
"acc@5": 91.516, "acc@5": 91.516,
} }
}, },
"_ops": 15.47,
"_file_size": 527.866,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -262,6 +276,8 @@ class VGG19_Weights(WeightsEnum): ...@@ -262,6 +276,8 @@ class VGG19_Weights(WeightsEnum):
"acc@5": 90.876, "acc@5": 90.876,
} }
}, },
"_ops": 19.632,
"_file_size": 548.051,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -280,6 +296,8 @@ class VGG19_BN_Weights(WeightsEnum): ...@@ -280,6 +296,8 @@ class VGG19_BN_Weights(WeightsEnum):
"acc@5": 91.842, "acc@5": 91.842,
} }
}, },
"_ops": 19.632,
"_file_size": 548.143,
}, },
) )
DEFAULT = IMAGENET1K_V1 DEFAULT = IMAGENET1K_V1
...@@ -491,21 +509,3 @@ def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = Tru ...@@ -491,21 +509,3 @@ def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = Tru
weights = VGG19_BN_Weights.verify(weights) weights = VGG19_BN_Weights.verify(weights)
return _vgg("E", True, weights, progress, **kwargs) return _vgg("E", True, weights, progress, **kwargs)
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"vgg11": VGG11_Weights.IMAGENET1K_V1.url,
"vgg13": VGG13_Weights.IMAGENET1K_V1.url,
"vgg16": VGG16_Weights.IMAGENET1K_V1.url,
"vgg19": VGG19_Weights.IMAGENET1K_V1.url,
"vgg11_bn": VGG11_BN_Weights.IMAGENET1K_V1.url,
"vgg13_bn": VGG13_BN_Weights.IMAGENET1K_V1.url,
"vgg16_bn": VGG16_BN_Weights.IMAGENET1K_V1.url,
"vgg19_bn": VGG19_BN_Weights.IMAGENET1K_V1.url,
}
)
from .mvit import * from .mvit import *
from .resnet import * from .resnet import *
from .s3d import * from .s3d import *
from .swin_transformer import *
...@@ -593,7 +593,7 @@ def _mvit( ...@@ -593,7 +593,7 @@ def _mvit(
) )
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -624,6 +624,8 @@ class MViT_V1_B_Weights(WeightsEnum): ...@@ -624,6 +624,8 @@ class MViT_V1_B_Weights(WeightsEnum):
"acc@5": 93.582, "acc@5": 93.582,
} }
}, },
"_ops": 70.599,
"_file_size": 139.764,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
...@@ -655,6 +657,8 @@ class MViT_V2_S_Weights(WeightsEnum): ...@@ -655,6 +657,8 @@ class MViT_V2_S_Weights(WeightsEnum):
"acc@5": 94.665, "acc@5": 94.665,
} }
}, },
"_ops": 64.224,
"_file_size": 131.884,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
...@@ -761,9 +765,10 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T ...@@ -761,9 +765,10 @@ def mvit_v1_b(*, weights: Optional[MViT_V1_B_Weights] = None, progress: bool = T
@register_model() @register_model()
@handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1)) @handle_legacy_interface(weights=("pretrained", MViT_V2_S_Weights.KINETICS400_V1))
def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT: def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = True, **kwargs: Any) -> MViT:
""" """Constructs a small MViTV2 architecture from
Constructs a small MViTV2 architecture from `Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__ and
`Multiscale Vision Transformers <https://arxiv.org/abs/2104.11227>`__. `MViTv2: Improved Multiscale Vision Transformers for Classification
and Detection <https://arxiv.org/abs/2112.01526>`__.
.. betastatus:: video module .. betastatus:: video module
...@@ -781,7 +786,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T ...@@ -781,7 +786,7 @@ def mvit_v2_s(*, weights: Optional[MViT_V2_S_Weights] = None, progress: bool = T
for more details about this class. for more details about this class.
.. autoclass:: torchvision.models.video.MViT_V2_S_Weights .. autoclass:: torchvision.models.video.MViT_V2_S_Weights
:members: :members:
""" """
weights = MViT_V2_S_Weights.verify(weights) weights = MViT_V2_S_Weights.verify(weights)
......
...@@ -303,7 +303,7 @@ def _video_resnet( ...@@ -303,7 +303,7 @@ def _video_resnet(
model = VideoResNet(block, conv_makers, layers, stem, **kwargs) model = VideoResNet(block, conv_makers, layers, stem, **kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -332,6 +332,8 @@ class R3D_18_Weights(WeightsEnum): ...@@ -332,6 +332,8 @@ class R3D_18_Weights(WeightsEnum):
"acc@5": 83.479, "acc@5": 83.479,
} }
}, },
"_ops": 40.697,
"_file_size": 127.359,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
...@@ -350,6 +352,8 @@ class MC3_18_Weights(WeightsEnum): ...@@ -350,6 +352,8 @@ class MC3_18_Weights(WeightsEnum):
"acc@5": 84.130, "acc@5": 84.130,
} }
}, },
"_ops": 43.343,
"_file_size": 44.672,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
...@@ -368,6 +372,8 @@ class R2Plus1D_18_Weights(WeightsEnum): ...@@ -368,6 +372,8 @@ class R2Plus1D_18_Weights(WeightsEnum):
"acc@5": 86.175, "acc@5": 86.175,
} }
}, },
"_ops": 40.519,
"_file_size": 120.318,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
......
...@@ -175,6 +175,8 @@ class S3D_Weights(WeightsEnum): ...@@ -175,6 +175,8 @@ class S3D_Weights(WeightsEnum):
"acc@5": 88.050, "acc@5": 88.050,
} }
}, },
"_ops": 17.979,
"_file_size": 31.972,
}, },
) )
DEFAULT = KINETICS400_V1 DEFAULT = KINETICS400_V1
...@@ -212,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg ...@@ -212,6 +214,6 @@ def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwarg
model = S3D(**kwargs) model = S3D(**kwargs)
if weights is not None: if weights is not None:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
# Modified from 2d Swin Transformers in torchvision:
# https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py
from functools import partial
from typing import Any, Callable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param, handle_legacy_interface
from ..swin_transformer import PatchMerging, SwinTransformerBlock
# Public API of this module: the 3D Swin Transformer backbone, its
# pretrained-weight enums, and the three builder functions.
__all__ = [
    "SwinTransformer3d",
    "Swin3D_T_Weights",
    "Swin3D_S_Weights",
    "Swin3D_B_Weights",
    "swin3d_t",
    "swin3d_s",
    "swin3d_b",
]
def _get_window_and_shift_size(
shift_size: List[int], size_dhw: List[int], window_size: List[int]
) -> Tuple[List[int], List[int]]:
for i in range(3):
if size_dhw[i] <= window_size[i]:
# In this case, window_size will adapt to the input size, and no need to shift
window_size[i] = size_dhw[i]
shift_size[i] = 0
return window_size, shift_size
torch.fx.wrap("_get_window_and_shift_size")
def _get_relative_position_bias(
relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int]
) -> Tensor:
window_vol = window_size[0] * window_size[1] * window_size[2]
# In 3d case we flatten the relative_position_bias
relative_position_bias = relative_position_bias_table[
relative_position_index[:window_vol, :window_vol].flatten() # type: ignore[index]
]
relative_position_bias = relative_position_bias.view(window_vol, window_vol, -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
return relative_position_bias
torch.fx.wrap("_get_relative_position_bias")
def _compute_pad_size_3d(size_dhw: Tuple[int, int, int], patch_size: Tuple[int, int, int]) -> Tuple[int, int, int]:
pad_size = [(patch_size[i] - size_dhw[i] % patch_size[i]) % patch_size[i] for i in range(3)]
return pad_size[0], pad_size[1], pad_size[2]
torch.fx.wrap("_compute_pad_size_3d")
def _compute_attention_mask_3d(
x: Tensor,
size_dhw: Tuple[int, int, int],
window_size: Tuple[int, int, int],
shift_size: Tuple[int, int, int],
) -> Tensor:
# generate attention mask
attn_mask = x.new_zeros(*size_dhw)
num_windows = (size_dhw[0] // window_size[0]) * (size_dhw[1] // window_size[1]) * (size_dhw[2] // window_size[2])
slices = [
(
(0, -window_size[i]),
(-window_size[i], -shift_size[i]),
(-shift_size[i], None),
)
for i in range(3)
]
count = 0
for d in slices[0]:
for h in slices[1]:
for w in slices[2]:
attn_mask[d[0] : d[1], h[0] : h[1], w[0] : w[1]] = count
count += 1
# Partition window on attn_mask
attn_mask = attn_mask.view(
size_dhw[0] // window_size[0],
window_size[0],
size_dhw[1] // window_size[1],
window_size[1],
size_dhw[2] // window_size[2],
window_size[2],
)
attn_mask = attn_mask.permute(0, 2, 4, 1, 3, 5).reshape(
num_windows, window_size[0] * window_size[1] * window_size[2]
)
attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
torch.fx.wrap("_compute_attention_mask_3d")
def shifted_window_attention_3d(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        input (Tensor[B, T, H, W, C]): The input tensor, 5-dimensions.
        qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value.
        proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection.
        relative_position_bias (Tensor): The learned relative position bias added to attention.
        window_size (List[int]): 3-dimensions window size, T, H, W .
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention (T, H, W).
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[B, T, H, W, C]: The output tensor after shifted window attention.
    """
    b, t, h, w, c = input.shape
    # pad feature maps to multiples of window size
    # (F.pad takes pairs last-dim first: channel, W, H, T)
    pad_size = _compute_pad_size_3d((t, h, w), (window_size[0], window_size[1], window_size[2]))
    x = F.pad(input, (0, 0, 0, pad_size[2], 0, pad_size[1], 0, pad_size[0]))
    _, tp, hp, wp, _ = x.shape
    padded_size = (tp, hp, wp)

    # cyclic shift (negative roll; undone symmetrically at the end)
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))

    # partition windows: (B, nD, Wd, nH, Wh, nW, Ww, C) -> (B*nWindows, Wd*Wh*Ww, C)
    num_windows = (
        (padded_size[0] // window_size[0]) * (padded_size[1] // window_size[1]) * (padded_size[2] // window_size[2])
    )
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        window_size[0],
        padded_size[1] // window_size[1],
        window_size[1],
        padded_size[2] // window_size[2],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(
        b * num_windows, window_size[0] * window_size[1] * window_size[2], c
    )  # B*nW, Wd*Wh*Ww, C

    # multi-head attention
    qkv = F.linear(x, qkv_weight, qkv_bias)
    # -> (3, B*nW, num_heads, N, head_dim)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, c // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    # scale queries by 1/sqrt(head_dim)
    q = q * (c // num_heads) ** -0.5
    attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask to handle shifted windows with varying size
        attn_mask = _compute_attention_mask_3d(
            x,
            (padded_size[0], padded_size[1], padded_size[2]),
            (window_size[0], window_size[1], window_size[2]),
            (shift_size[0], shift_size[1], shift_size[2]),
        )
        # broadcast the per-window mask over batch and heads
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))
    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), c)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(
        b,
        padded_size[0] // window_size[0],
        padded_size[1] // window_size[1],
        padded_size[2] // window_size[2],
        window_size[0],
        window_size[1],
        window_size[2],
        c,
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).reshape(b, tp, hp, wp, c)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))

    # unpad features back to the original (t, h, w)
    x = x[:, :t, :h, :w, :].contiguous()
    return x


torch.fx.wrap("shifted_window_attention_3d")
class ShiftedWindowAttention3d(nn.Module):
    """
    Module wrapper around :func:`shifted_window_attention_3d` that owns the
    qkv/projection weights and the learned relative position bias.

    Args:
        dim (int): Embedding dimension of the input tokens.
        window_size (List[int]): 3D attention window size (T, H, W).
        shift_size (List[int]): 3D cyclic shift of the window (T, H, W).
        num_heads (int): Number of attention heads.
        qkv_bias (bool): Whether the qkv projection has a bias. Default: True.
        proj_bias (bool): Whether the output projection has a bias. Default: True.
        attention_dropout (float): Dropout ratio of attention weights. Default: 0.0.
        dropout (float): Dropout ratio of the output. Default: 0.0.

    Raises:
        ValueError: If ``window_size`` or ``shift_size`` does not have exactly 3 elements.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if len(window_size) != 3 or len(shift_size) != 3:
            # Fixed: the message previously said "length 2" (copied from the
            # 2D implementation) while the check requires length 3.
            raise ValueError("window_size and shift_size must be of length 3")

        self.window_size = window_size  # Wd, Wh, Ww
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self) -> None:
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(
                (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1),
                self.num_heads,
            )
        )  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self) -> None:
        # get pair-wise relative position index for each token inside the window
        coords_dhw = [torch.arange(self.window_size[i]) for i in range(3)]
        coords = torch.stack(
            torch.meshgrid(coords_dhw[0], coords_dhw[1], coords_dhw[2], indexing="ij")
        )  # 3, Wd, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1
        # mixed-radix encoding so the 3D offset collapses into a single index
        relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1
        # We don't flatten the relative_position_index here in 3d case.
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self, window_size: List[int]) -> torch.Tensor:
        # window_size is passed explicitly because forward() may have clamped it
        return _get_relative_position_bias(self.relative_position_bias_table, self.relative_position_index, window_size)  # type: ignore

    def forward(self, x: Tensor) -> Tensor:
        """Apply (shifted) window attention to ``x`` of shape (B, T, H, W, C);
        the output has the same shape."""
        _, t, h, w, _ = x.shape
        size_dhw = [t, h, w]
        # copy() so the module's configured sizes are never mutated
        window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
        # Handle case where window_size is larger than the input tensor
        window_size, shift_size = _get_window_and_shift_size(shift_size, size_dhw, window_size)

        relative_position_bias = self.get_relative_position_bias(window_size)

        return shifted_window_attention_3d(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            window_size,
            self.num_heads,
            shift_size=shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )
# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
    """Video to Patch Embedding.

    Splits a (B, C, T, H, W) video into non-overlapping 3D patches and
    linearly projects each patch to ``embed_dim`` channels.

    Args:
        patch_size (List[int]): Patch token size.
        in_channels (int): Number of input channels. Default: 3
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(
        self,
        patch_size: List[int],
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.tuple_patch_size = (patch_size[0], patch_size[1], patch_size[2])

        # A Conv3d whose kernel and stride both equal the patch size acts as
        # an exact non-overlapping patch projection.
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.tuple_patch_size,
            stride=self.tuple_patch_size,
        )
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function."""
        _, _, t, h, w = x.size()
        # Right-pad each axis so it divides evenly into patches.
        pad_dhw = _compute_pad_size_3d((t, h, w), self.tuple_patch_size)
        x = F.pad(x, (0, pad_dhw[2], 0, pad_dhw[1], 0, pad_dhw[0]))
        x = self.proj(x)  # B C T' H' W'
        x = x.permute(0, 2, 3, 4, 1)  # B T' H' W' C
        if self.norm is not None:
            x = self.norm(x)
        return x
class SwinTransformer3d(nn.Module):
    """
    Implements 3D Swin Transformer from the `"Video Swin Transformer" <https://arxiv.org/abs/2106.13230>`_ paper.
    Args:
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Number of classes for classification head. Default: 400.
        norm_layer (nn.Module, optional): Normalization layer. Default: None.
        block (nn.Module, optional): SwinTransformer Block. Default: None.
        downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging.
        patch_embed (nn.Module, optional): Patch Embedding layer. Default: None.
    """

    def __init__(
        self,
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 400,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = None,
        downsample_layer: Callable[..., nn.Module] = PatchMerging,
        patch_embed: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            # reuse the 2D SwinTransformerBlock with the 3D attention layer
            block = partial(SwinTransformerBlock, attn_layer=ShiftedWindowAttention3d)

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        if patch_embed is None:
            patch_embed = PatchEmbed3d

        # split image into non-overlapping patches
        self.patch_embed = patch_embed(patch_size=patch_size, embed_dim=embed_dim, norm_layer=norm_layer)
        self.pos_drop = nn.Dropout(p=dropout)

        layers: List[nn.Module] = []
        total_stage_blocks = sum(depths)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            stage: List[nn.Module] = []
            # channel width doubles at every stage
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                # adjust stochastic depth probability based on the depth of the stage block
                sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1)
                stage.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        # alternate non-shifted / shifted windows across layers
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                        attn_layer=ShiftedWindowAttention3d,
                    )
                )
                stage_block_id += 1
            layers.append(nn.Sequential(*stage))
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.Sequential(*layers)

        self.num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.head = nn.Linear(self.num_features, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        # x: B C T H W
        x = self.patch_embed(x)  # B _T _H _W C
        x = self.pos_drop(x)
        x = self.features(x)  # B _T _H _W C
        x = self.norm(x)
        x = x.permute(0, 4, 1, 2, 3)  # B, C, _T, _H, _W
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.head(x)
        return x
def _swin_transformer3d(
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    weights: Optional[WeightsEnum],
    progress: bool,
    **kwargs: Any,
) -> SwinTransformer3d:
    """Instantiate a :class:`SwinTransformer3d` and optionally load pretrained weights.

    When ``weights`` is given, ``num_classes`` is forced to match the number of
    categories the checkpoint was trained on, and the state dict is downloaded
    (with hash verification) and loaded into the model.
    """
    if weights is not None:
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    arch_kwargs = dict(
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
    )
    model = SwinTransformer3d(**arch_kwargs, **kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))

    return model
# Metadata shared by every Swin3D weight entry below: all checkpoints target
# the Kinetics-400 label set and declare a minimum input of 1x1 pixels and
# one frame.
_COMMON_META = {
    "categories": _KINETICS400_CATEGORIES,
    "min_size": (1, 1),
    "min_temporal_size": 1,
}
class Swin3D_T_Weights(WeightsEnum):
    """Pretrained weights for the tiny Video Swin Transformer (``swin3d_t``)."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_t-7615ae03.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            # normalization statistics (the standard ImageNet mean/std values)
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 28158070,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 77.715,
                    "acc@5": 93.519,
                }
            },
            "_ops": 43.882,
            "_file_size": 121.543,
        },
    )
    DEFAULT = KINETICS400_V1
class Swin3D_S_Weights(WeightsEnum):
    """Pretrained weights for the small Video Swin Transformer (``swin3d_s``)."""

    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_s-da41c237.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            # normalization statistics (the standard ImageNet mean/std values)
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 49816678,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.521,
                    "acc@5": 94.158,
                }
            },
            "_ops": 82.841,
            "_file_size": 218.288,
        },
    )
    DEFAULT = KINETICS400_V1
class Swin3D_B_Weights(WeightsEnum):
    """Pretrained weights for the base Video Swin Transformer (``swin3d_b``)."""

    # Checkpoint from the "_1k" release (per the URL naming); see
    # KINETICS400_IMAGENET22K_V1 below for the ImageNet-22K-pretrained variant.
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_1k-24f7c7c6.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            # normalization statistics (the standard ImageNet mean/std values)
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 79.427,
                    "acc@5": 94.386,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    # Same architecture as above but from the "_22k" checkpoint (ImageNet-22K
    # pretraining, per the member name); higher Kinetics-400 accuracy.
    KINETICS400_IMAGENET22K_V1 = Weights(
        url="https://download.pytorch.org/models/swin3d_b_22k-7c6ae6fa.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256,),
            # normalization statistics (the standard ImageNet mean/std values)
            mean=(0.4850, 0.4560, 0.4060),
            std=(0.2290, 0.2240, 0.2250),
        ),
        meta={
            **_COMMON_META,
            "recipe": "https://github.com/SwinTransformer/Video-Swin-Transformer#kinetics-400",
            "_docs": (
                "The weights were ported from the paper. The accuracies are estimated on video-level "
                "with parameters `frame_rate=15`, `clips_per_video=12`, and `clip_len=32`"
            ),
            "num_params": 88048984,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 81.643,
                    "acc@5": 95.574,
                }
            },
            "_ops": 140.667,
            "_file_size": 364.134,
        },
    )
    DEFAULT = KINETICS400_V1
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_T_Weights.KINETICS400_V1))
def swin3d_t(*, weights: Optional[Swin3D_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_tiny architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_T_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.video.Swin3D_T_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_T_Weights
        :members:
    """
    weights = Swin3D_T_Weights.verify(weights)

    # Tiny configuration: embed_dim 96 with a shallow third stage.
    return _swin_transformer3d(
        weights=weights,
        progress=progress,
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_S_Weights.KINETICS400_V1))
def swin3d_s(*, weights: Optional[Swin3D_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_small architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_S_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.video.Swin3D_S_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_S_Weights
        :members:
    """
    weights = Swin3D_S_Weights.verify(weights)

    # Small configuration: same width as tiny (96) but a deeper third stage.
    return _swin_transformer3d(
        weights=weights,
        progress=progress,
        patch_size=[2, 4, 4],
        embed_dim=96,
        depths=[2, 2, 18, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        **kwargs,
    )
@register_model()
@handle_legacy_interface(weights=("pretrained", Swin3D_B_Weights.KINETICS400_V1))
def swin3d_b(*, weights: Optional[Swin3D_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer3d:
    """
    Constructs a swin_base architecture from
    `Video Swin Transformer <https://arxiv.org/abs/2106.13230>`_.

    Args:
        weights (:class:`~torchvision.models.video.Swin3D_B_Weights`, optional): The pretrained
            weights to use. See :class:`~torchvision.models.video.Swin3D_B_Weights` below for
            more details, and possible values. By default, no pre-trained weights are used.
        progress (bool, optional): If True, displays a progress bar of the download to stderr.
            Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.swin_transformer.SwinTransformer``
            base class. Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/swin_transformer.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.Swin3D_B_Weights
        :members:
    """
    weights = Swin3D_B_Weights.verify(weights)

    # Base configuration: wider (embed_dim 128) with the deep third stage.
    return _swin_transformer3d(
        weights=weights,
        progress=progress,
        patch_size=[2, 4, 4],
        embed_dim=128,
        depths=[2, 2, 18, 2],
        num_heads=[4, 8, 16, 32],
        window_size=[8, 7, 7],
        stochastic_depth_prob=0.1,
        **kwargs,
    )
...@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module): ...@@ -110,7 +110,7 @@ class EncoderBlock(nn.Module):
def forward(self, input: torch.Tensor): def forward(self, input: torch.Tensor):
torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}") torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
x = self.ln_1(input) x = self.ln_1(input)
x, _ = self.self_attention(query=x, key=x, value=x, need_weights=False) x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x) x = self.dropout(x)
x = x + input x = x + input
...@@ -332,7 +332,7 @@ def _vision_transformer( ...@@ -332,7 +332,7 @@ def _vision_transformer(
) )
if weights: if weights:
model.load_state_dict(weights.get_state_dict(progress=progress)) model.load_state_dict(weights.get_state_dict(progress=progress, check_hash=True))
return model return model
...@@ -363,6 +363,8 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -363,6 +363,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 95.318, "acc@5": 95.318,
} }
}, },
"_ops": 17.564,
"_file_size": 330.285,
"_docs": """ "_docs": """
These weights were trained from scratch by using a modified version of `DeIT These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe. <https://arxiv.org/abs/2012.12877>`_'s training recipe.
...@@ -387,6 +389,8 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -387,6 +389,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 97.650, "acc@5": 97.650,
} }
}, },
"_ops": 55.484,
"_file_size": 331.398,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -412,6 +416,8 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -412,6 +416,8 @@ class ViT_B_16_Weights(WeightsEnum):
"acc@5": 96.180, "acc@5": 96.180,
} }
}, },
"_ops": 17.564,
"_file_size": 330.285,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -436,6 +442,8 @@ class ViT_B_32_Weights(WeightsEnum): ...@@ -436,6 +442,8 @@ class ViT_B_32_Weights(WeightsEnum):
"acc@5": 92.466, "acc@5": 92.466,
} }
}, },
"_ops": 4.409,
"_file_size": 336.604,
"_docs": """ "_docs": """
These weights were trained from scratch by using a modified version of `DeIT These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe. <https://arxiv.org/abs/2012.12877>`_'s training recipe.
...@@ -460,6 +468,8 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -460,6 +468,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 94.638, "acc@5": 94.638,
} }
}, },
"_ops": 61.555,
"_file_size": 1161.023,
"_docs": """ "_docs": """
These weights were trained from scratch by using a modified version of TorchVision's These weights were trained from scratch by using a modified version of TorchVision's
`new training recipe `new training recipe
...@@ -485,6 +495,8 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -485,6 +495,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 98.512, "acc@5": 98.512,
} }
}, },
"_ops": 361.986,
"_file_size": 1164.258,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -510,6 +522,8 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -510,6 +522,8 @@ class ViT_L_16_Weights(WeightsEnum):
"acc@5": 97.422, "acc@5": 97.422,
} }
}, },
"_ops": 61.555,
"_file_size": 1161.023,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -534,6 +548,8 @@ class ViT_L_32_Weights(WeightsEnum): ...@@ -534,6 +548,8 @@ class ViT_L_32_Weights(WeightsEnum):
"acc@5": 93.07, "acc@5": 93.07,
} }
}, },
"_ops": 15.378,
"_file_size": 1169.449,
"_docs": """ "_docs": """
These weights were trained from scratch by using a modified version of `DeIT These weights were trained from scratch by using a modified version of `DeIT
<https://arxiv.org/abs/2012.12877>`_'s training recipe. <https://arxiv.org/abs/2012.12877>`_'s training recipe.
...@@ -562,6 +578,8 @@ class ViT_H_14_Weights(WeightsEnum): ...@@ -562,6 +578,8 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5": 98.694, "acc@5": 98.694,
} }
}, },
"_ops": 1016.717,
"_file_size": 2416.643,
"_docs": """ "_docs": """
These weights are learnt via transfer learning by end-to-end fine-tuning the original These weights are learnt via transfer learning by end-to-end fine-tuning the original
`SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data. `SWAG <https://arxiv.org/abs/2201.08371>`_ weights on ImageNet-1K data.
...@@ -587,6 +605,8 @@ class ViT_H_14_Weights(WeightsEnum): ...@@ -587,6 +605,8 @@ class ViT_H_14_Weights(WeightsEnum):
"acc@5": 97.730, "acc@5": 97.730,
} }
}, },
"_ops": 167.295,
"_file_size": 2411.209,
"_docs": """ "_docs": """
These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk These weights are composed of the original frozen `SWAG <https://arxiv.org/abs/2201.08371>`_ trunk
weights and a linear classifier learnt on top of them trained on ImageNet-1K data. weights and a linear classifier learnt on top of them trained on ImageNet-1K data.
...@@ -773,7 +793,7 @@ def interpolate_embeddings( ...@@ -773,7 +793,7 @@ def interpolate_embeddings(
interpolation_mode: str = "bicubic", interpolation_mode: str = "bicubic",
reset_heads: bool = False, reset_heads: bool = False,
) -> "OrderedDict[str, torch.Tensor]": ) -> "OrderedDict[str, torch.Tensor]":
"""This function helps interpolating positional embeddings during checkpoint loading, """This function helps interpolate positional embeddings during checkpoint loading,
especially when you want to apply a pre-trained model on images with different resolution. especially when you want to apply a pre-trained model on images with different resolution.
Args: Args:
...@@ -798,7 +818,7 @@ def interpolate_embeddings( ...@@ -798,7 +818,7 @@ def interpolate_embeddings(
# We do this by reshaping the positions embeddings to a 2d grid, performing # We do this by reshaping the positions embeddings to a 2d grid, performing
# an interpolation in the (h, w) space and then reshaping back to a 1d grid. # an interpolation in the (h, w) space and then reshaping back to a 1d grid.
if new_seq_length != seq_length: if new_seq_length != seq_length:
# The class token embedding shouldn't be interpolated so we split it up. # The class token embedding shouldn't be interpolated, so we split it up.
seq_length -= 1 seq_length -= 1
new_seq_length -= 1 new_seq_length -= 1
pos_embedding_token = pos_embedding[:, :1, :] pos_embedding_token = pos_embedding[:, :1, :]
...@@ -842,17 +862,3 @@ def interpolate_embeddings( ...@@ -842,17 +862,3 @@ def interpolate_embeddings(
model_state = model_state_copy model_state = model_state_copy
return model_state return model_state
# The dictionary below is internal implementation detail and will be removed in v0.15
from ._utils import _ModelURLs
model_urls = _ModelURLs(
{
"vit_b_16": ViT_B_16_Weights.IMAGENET1K_V1.url,
"vit_b_32": ViT_B_32_Weights.IMAGENET1K_V1.url,
"vit_l_16": ViT_L_16_Weights.IMAGENET1K_V1.url,
"vit_l_32": ViT_L_32_Weights.IMAGENET1K_V1.url,
}
)
...@@ -50,7 +50,7 @@ def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor: ...@@ -50,7 +50,7 @@ def _box_xyxy_to_cxcywh(boxes: Tensor) -> Tensor:
def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor: def _box_xywh_to_xyxy(boxes: Tensor) -> Tensor:
""" """
Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format. Converts bounding boxes from (x, y, w, h) format to (x1, y1, x2, y2) format.
(x, y) refers to top left of bouding box. (x, y) refers to top left of bounding box.
(w, h) refers to width and height of box. (w, h) refers to width and height of box.
Args: Args:
boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted. boxes (Tensor[N, 4]): boxes in (x, y, w, h) which will be converted.
......
...@@ -2,65 +2,106 @@ import sys ...@@ -2,65 +2,106 @@ import sys
import warnings import warnings
import torch import torch
from torch.onnx import symbolic_opset11 as opset11
from torch.onnx.symbolic_helper import parse_args
_onnx_opset_version = 11 _ONNX_OPSET_VERSION_11 = 11
_ONNX_OPSET_VERSION_16 = 16
BASE_ONNX_OPSET_VERSION = _ONNX_OPSET_VERSION_11
def _register_custom_op(): @parse_args("v", "v", "f")
from torch.onnx.symbolic_helper import parse_args def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze boxes = opset11.unsqueeze(g, boxes, 0)
from torch.onnx.symbolic_opset9 import _cast_Long scores = opset11.unsqueeze(g, opset11.unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
@parse_args("v", "v", "f") iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = unsqueeze(g, boxes, 0) # Cast boxes and scores to float32 in case they are float64 inputs
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0) nms_out = g.op(
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long)) "NonMaxSuppression",
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float)) g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold) g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1) max_output_per_class,
iou_threshold,
@parse_args("v", "v", "f", "i", "i", "i", "i") )
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): return opset11.squeeze(
batch_indices = _cast_Long( g, opset11.select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False )
)
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# TODO: Remove this warning after ONNX opset 16 is supported. def _process_batch_indices_for_roi_align(g, rois):
if aligned: indices = opset11.squeeze(
warnings.warn( g, opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. " )
"The workaround is that the user need apply the patch " return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
) def _process_rois_for_roi_align(g, rois):
return opset11.select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# ONNX doesn't support negative sampling_ratio
if sampling_ratio < 0:
warnings.warn( def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported." if sampling_ratio < 0:
) warnings.warn(
sampling_ratio = 0 "ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
return g.op( "The model will be exported with a sampling_ratio of 0."
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
) )
sampling_ratio = 0
return sampling_ratio
@parse_args("v", "v", "f", "i", "i") @parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width): def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
roi_pool = g.op( batch_indices = _process_batch_indices_for_roi_align(g, rois)
"MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale rois = _process_rois_for_roi_align(g, rois)
if aligned:
warnings.warn(
"ROIAlign with aligned=True is only supported in opset >= 16. "
"Please export with opset 16 or higher, or use aligned=False."
) )
return roi_pool, None sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)
@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
coordinate_transformation_mode_s=coordinate_transformation_mode,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)
from torch.onnx import register_custom_op_symbolic
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version) @parse_args("v", "v", "f", "i", "i")
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version) def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version) roi_pool = g.op(
"MaxRoiPool", input, rois, pooled_shape_i=(pooled_height, pooled_width), spatial_scale_f=spatial_scale
)
return roi_pool, None
def _register_custom_op():
torch.onnx.register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _ONNX_OPSET_VERSION_11)
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _ONNX_OPSET_VERSION_11)
torch.onnx.register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _ONNX_OPSET_VERSION_16)
torch.onnx.register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _ONNX_OPSET_VERSION_11)
...@@ -63,9 +63,16 @@ def complete_box_iou_loss( ...@@ -63,9 +63,16 @@ def complete_box_iou_loss(
alpha = v / (1 - iou + v + eps) alpha = v / (1 - iou + v + eps)
loss = diou_loss + alpha * v loss = diou_loss + alpha * v
if reduction == "mean":
# Check reduction option and return loss accordingly
if reduction == "none":
pass
elif reduction == "mean":
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
elif reduction == "sum": elif reduction == "sum":
loss = loss.sum() loss = loss.sum()
else:
raise ValueError(
f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
)
return loss return loss
...@@ -68,7 +68,7 @@ def deform_conv2d( ...@@ -68,7 +68,7 @@ def deform_conv2d(
use_mask = mask is not None use_mask = mask is not None
if mask is None: if mask is None:
mask = torch.zeros((input.shape[0], 0), device=input.device, dtype=input.dtype) mask = torch.zeros((input.shape[0], 1), device=input.device, dtype=input.dtype)
if bias is None: if bias is None:
bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype) bias = torch.zeros(out_channels, device=input.device, dtype=input.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment