Unverified commit d8654bb0, authored by Vasilis Vryniotis, committed by GitHub

Refactor preset transforms (#5562)

* Refactor preset transforms

* Making presets public.
parent 2b5ab1bc
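
The diff below touches every prototype weights file that referenced an evaluation preset by its old dataset-flavored name, replacing it with a task-flavored one. A before/after sketch of the imports, taken directly from the hunks that follow:

    # Before: presets were named after the dataset they were tuned on
    from torchvision.prototype.transforms import ImageNetEval, RaftEval, VocEval, Kinect400Eval

    # After: presets are named after the task they serve
    from torchvision.prototype.transforms import (
        ImageClassificationEval,   # was ImageNetEval
        OpticalFlowEval,           # was RaftEval
        SemanticSegmentationEval,  # was VocEval
        VideoClassificationEval,   # was Kinect400Eval
    )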
 from functools import partial
 from typing import Any, Optional, List
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
@@ -51,7 +51,7 @@ _COMMON_META = {
 class MobileNet_V3_Large_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 5483032,
@@ -62,7 +62,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 5483032,
@@ -77,7 +77,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
 class MobileNet_V3_Small_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 2542856,
...
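
Each Weights entry stores its preset as a zero-argument factory (a partial over the preset class), so callers never spell the preset name themselves. A minimal usage sketch, assuming the prototype builders of this era and a hypothetical input image img:

    from torchvision.prototype.models import mobilenet_v3_large, MobileNet_V3_Large_Weights

    weights = MobileNet_V3_Large_Weights.IMAGENET1K_V1
    model = mobilenet_v3_large(weights=weights).eval()
    preprocess = weights.transforms()  # instantiates ImageClassificationEval(crop_size=224)
    # batch = preprocess(img).unsqueeze(0)  # img: PIL image or CHW tensor (assumption)
    # logits = model(batch)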
@@ -4,7 +4,7 @@ from torch.nn.modules.batchnorm import BatchNorm2d
 from torch.nn.modules.instancenorm import InstanceNorm2d
 from torchvision.models.optical_flow import RAFT
 from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
-from torchvision.prototype.transforms import RaftEval
+from torchvision.prototype.transforms import OpticalFlowEval
 from torchvision.transforms.functional import InterpolationMode
 from .._api import WeightsEnum
@@ -33,7 +33,7 @@ class Raft_Large_Weights(WeightsEnum):
     C_T_V1 = Weights(
         # Chairs + Things, ported from original paper repo (raft-things.pth)
         url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -48,7 +48,7 @@ class Raft_Large_Weights(WeightsEnum):
     C_T_V2 = Weights(
         # Chairs + Things
         url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -63,7 +63,7 @@ class Raft_Large_Weights(WeightsEnum):
     C_T_SKHT_V1 = Weights(
         # Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
         url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -78,7 +78,7 @@ class Raft_Large_Weights(WeightsEnum):
         # Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
         # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
         url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -91,7 +91,7 @@ class Raft_Large_Weights(WeightsEnum):
     C_T_SKHT_K_V1 = Weights(
         # Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
         url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -106,7 +106,7 @@ class Raft_Large_Weights(WeightsEnum):
         # Same as CT_SKHT with extra fine-tuning on Kitti
         # Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
         url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 5257536,
@@ -122,7 +122,7 @@ class Raft_Small_Weights(WeightsEnum):
     C_T_V1 = Weights(
         # Chairs + Things, ported from original paper repo (raft-small.pth)
         url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 990162,
@@ -136,7 +136,7 @@ class Raft_Small_Weights(WeightsEnum):
     C_T_V2 = Weights(
         # Chairs + Things
         url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
-        transforms=RaftEval,
+        transforms=OpticalFlowEval,
         meta={
             **_COMMON_META,
             "num_params": 990162,
...
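
Unlike the classification entries, the optical-flow entries pass the preset class itself rather than a partial, since it needs no per-model arguments; weights.transforms() instantiates it the same way. A sketch under the same assumptions, with hypothetical frame tensors img1 and img2:

    from torchvision.prototype.models.optical_flow import raft_large, Raft_Large_Weights

    weights = Raft_Large_Weights.C_T_SKHT_V2
    model = raft_large(weights=weights).eval()
    preprocess = weights.transforms()  # instantiates OpticalFlowEval()
    # img1, img2 = preprocess(img1, img2)  # exact call signature is an assumption
    # flow_predictions = model(img1.unsqueeze(0), img2.unsqueeze(0))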
@@ -2,7 +2,7 @@ import warnings
 from functools import partial
 from typing import Any, Optional, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.googlenet import (
@@ -26,7 +26,7 @@ __all__ = [
 class GoogLeNet_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "GoogLeNet",
...
 from functools import partial
 from typing import Any, Optional, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.inception import (
@@ -25,7 +25,7 @@ __all__ = [
 class Inception_V3_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
-        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
+        transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
         meta={
             "task": "image_classification",
             "architecture": "InceptionV3",
...
 from functools import partial
 from typing import Any, Optional, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.mobilenetv2 import (
@@ -26,7 +26,7 @@ __all__ = [
 class MobileNet_V2_QuantizedWeights(WeightsEnum):
     IMAGENET1K_QNNPACK_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "MobileNetV2",
...
@@ -2,7 +2,7 @@ from functools import partial
 from typing import Any, List, Optional, Union
 import torch
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.mobilenetv3 import (
@@ -59,7 +59,7 @@ def _mobilenet_v3_model(
 class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
     IMAGENET1K_QNNPACK_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "MobileNetV3",
...
 from functools import partial
 from typing import Any, List, Optional, Type, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.resnet import (
@@ -68,7 +68,7 @@ _COMMON_META = {
 class ResNet18_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -85,7 +85,7 @@ class ResNet18_QuantizedWeights(WeightsEnum):
 class ResNet50_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -98,7 +98,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
     )
     IMAGENET1K_FBGEMM_V2 = Weights(
         url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -115,7 +115,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
 class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
@@ -128,7 +128,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
     )
     IMAGENET1K_FBGEMM_V2 = Weights(
         url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
...
 from functools import partial
 from typing import Any, List, Optional, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.quantization.shufflenetv2 import (
@@ -67,7 +67,7 @@ _COMMON_META = {
 class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 1366792,
@@ -82,7 +82,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
 class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
     IMAGENET1K_FBGEMM_V1 = Weights(
         url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 2278604,
...
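
The quantized classification weights above point at the very same ImageClassificationEval preset as their float counterparts, so the rename keeps preprocessing identical across backends; only model construction differs. A hedged sketch, assuming the prototype quantized builders accept a quantize flag as the stable ones do:

    from torchvision.prototype.models.quantization import googlenet, GoogLeNet_QuantizedWeights

    weights = GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1
    model = googlenet(weights=weights, quantize=True).eval()
    preprocess = weights.transforms()  # same ImageClassificationEval(crop_size=224) as the float model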
@@ -2,7 +2,7 @@ from functools import partial
 from typing import Any, Optional
 from torch import nn
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ...models.regnet import RegNet, BlockParams
@@ -77,7 +77,7 @@ def _regnet(
 class RegNet_Y_400MF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 4344144,
@@ -88,7 +88,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 4344144,
@@ -103,7 +103,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
 class RegNet_Y_800MF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 6432512,
@@ -114,7 +114,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 6432512,
@@ -129,7 +129,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
 class RegNet_Y_1_6GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 11202430,
@@ -140,7 +140,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 11202430,
@@ -155,7 +155,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
 class RegNet_Y_3_2GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 19436338,
@@ -166,7 +166,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 19436338,
@@ -181,7 +181,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
 class RegNet_Y_8GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 39381472,
@@ -192,7 +192,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 39381472,
@@ -207,7 +207,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
 class RegNet_Y_16GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 83590140,
@@ -218,7 +218,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 83590140,
@@ -233,7 +233,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
 class RegNet_Y_32GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 145046770,
@@ -244,7 +244,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 145046770,
@@ -264,7 +264,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
 class RegNet_X_400MF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 5495976,
@@ -275,7 +275,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 5495976,
@@ -290,7 +290,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
 class RegNet_X_800MF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 7259656,
@@ -301,7 +301,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 7259656,
@@ -316,7 +316,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
 class RegNet_X_1_6GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 9190136,
@@ -327,7 +327,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 9190136,
@@ -342,7 +342,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
 class RegNet_X_3_2GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 15296552,
@@ -353,7 +353,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 15296552,
@@ -368,7 +368,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
 class RegNet_X_8GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 39572648,
@@ -379,7 +379,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 39572648,
@@ -394,7 +394,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
 class RegNet_X_16GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 54278536,
@@ -405,7 +405,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 54278536,
@@ -420,7 +420,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
 class RegNet_X_32GF_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 107811560,
@@ -431,7 +431,7 @@ class RegNet_X_32GF_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 107811560,
...
 from functools import partial
 from typing import Any, List, Optional, Type, Union
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ...models.resnet import BasicBlock, Bottleneck, ResNet
@@ -63,7 +63,7 @@ _COMMON_META = {
 class ResNet18_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -80,7 +80,7 @@ class ResNet18_Weights(WeightsEnum):
 class ResNet34_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnet34-b627a593.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -97,7 +97,7 @@ class ResNet34_Weights(WeightsEnum):
 class ResNet50_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -110,7 +110,7 @@ class ResNet50_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -127,7 +127,7 @@ class ResNet50_Weights(WeightsEnum):
 class ResNet101_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -140,7 +140,7 @@ class ResNet101_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -157,7 +157,7 @@ class ResNet101_Weights(WeightsEnum):
 class ResNet152_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -170,7 +170,7 @@ class ResNet152_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNet",
@@ -187,7 +187,7 @@ class ResNet152_Weights(WeightsEnum):
 class ResNeXt50_32X4D_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
@@ -200,7 +200,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
@@ -217,7 +217,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
 class ResNeXt101_32X8D_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
@@ -230,7 +230,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "ResNeXt",
@@ -247,7 +247,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
 class Wide_ResNet50_2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "WideResNet",
@@ -260,7 +260,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "WideResNet",
@@ -277,7 +277,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
 class Wide_ResNet101_2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "architecture": "WideResNet",
@@ -290,7 +290,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "architecture": "WideResNet",
...
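
A pattern that repeats through every classification file above: V1 weights keep the legacy preprocessing (crop_size=224 with the default resize), while V2 weights carry resize_size=232. Because the preset travels with the weights, swapping enum members swaps preprocessing automatically; a sketch, assuming the enum members proxy their Weights fields as in the prototype API:

    from torchvision.prototype.models import ResNet50_Weights

    v1_preprocess = ResNet50_Weights.IMAGENET1K_V1.transforms()  # ImageClassificationEval(crop_size=224)
    v2_preprocess = ResNet50_Weights.IMAGENET1K_V2.transforms()  # ImageClassificationEval(crop_size=224, resize_size=232)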
 from functools import partial
 from typing import Any, Optional
-from torchvision.prototype.transforms import VocEval
+from torchvision.prototype.transforms import SemanticSegmentationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
@@ -36,7 +36,7 @@ _COMMON_META = {
 class DeepLabV3_ResNet50_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             **_COMMON_META,
             "num_params": 42004074,
@@ -51,7 +51,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
 class DeepLabV3_ResNet101_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             **_COMMON_META,
             "num_params": 60996202,
@@ -66,7 +66,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
 class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             **_COMMON_META,
             "num_params": 11029328,
...
 from functools import partial
 from typing import Any, Optional
-from torchvision.prototype.transforms import VocEval
+from torchvision.prototype.transforms import SemanticSegmentationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.segmentation.fcn import FCN, _fcn_resnet
@@ -26,7 +26,7 @@ _COMMON_META = {
 class FCN_ResNet50_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             **_COMMON_META,
             "num_params": 35322218,
@@ -41,7 +41,7 @@ class FCN_ResNet50_Weights(WeightsEnum):
 class FCN_ResNet101_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             **_COMMON_META,
             "num_params": 54314346,
...
 from functools import partial
 from typing import Any, Optional
-from torchvision.prototype.transforms import VocEval
+from torchvision.prototype.transforms import SemanticSegmentationEval
 from torchvision.transforms.functional import InterpolationMode
 from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
@@ -17,7 +17,7 @@ __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_l
 class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
     COCO_WITH_VOC_LABELS_V1 = Weights(
         url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
-        transforms=partial(VocEval, resize_size=520),
+        transforms=partial(SemanticSegmentationEval, resize_size=520),
         meta={
             "task": "image_semantic_segmentation",
             "architecture": "LRASPP",
...
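
The three segmentation files above swap VocEval for SemanticSegmentationEval, reflecting that the preset applies to any semantic-segmentation model rather than VOC specifically; every entry configures it with resize_size=520. A sketch under the same assumptions, with a hypothetical image img:

    from torchvision.prototype.models.segmentation import fcn_resnet50, FCN_ResNet50_Weights

    weights = FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1
    model = fcn_resnet50(weights=weights).eval()
    preprocess = weights.transforms()  # instantiates SemanticSegmentationEval(resize_size=520)
    # out = model(preprocess(img).unsqueeze(0))["out"]  # "out" key as in torchvision segmentation models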
 from functools import partial
 from typing import Any, Optional
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ...models.shufflenetv2 import ShuffleNetV2
@@ -55,7 +55,7 @@ _COMMON_META = {
 class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 1366792,
@@ -69,7 +69,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
 class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 2278604,
...
 from functools import partial
 from typing import Any, Optional
-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode
 from ...models.squeezenet import SqueezeNet
@@ -27,7 +27,7 @@ _COMMON_META = {
 class SqueezeNet1_0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "min_size": (21, 21),
@@ -42,7 +42,7 @@ class SqueezeNet1_0_Weights(WeightsEnum):
 class SqueezeNet1_1_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "min_size": (17, 17),
...
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs from ...models.vgg import VGG, make_layers, cfgs
...@@ -55,7 +55,7 @@ _COMMON_META = { ...@@ -55,7 +55,7 @@ _COMMON_META = {
class VGG11_Weights(WeightsEnum): class VGG11_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth", url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 132863336, "num_params": 132863336,
...@@ -69,7 +69,7 @@ class VGG11_Weights(WeightsEnum): ...@@ -69,7 +69,7 @@ class VGG11_Weights(WeightsEnum):
class VGG11_BN_Weights(WeightsEnum): class VGG11_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 132868840, "num_params": 132868840,
...@@ -83,7 +83,7 @@ class VGG11_BN_Weights(WeightsEnum): ...@@ -83,7 +83,7 @@ class VGG11_BN_Weights(WeightsEnum):
class VGG13_Weights(WeightsEnum): class VGG13_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth", url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 133047848, "num_params": 133047848,
...@@ -97,7 +97,7 @@ class VGG13_Weights(WeightsEnum): ...@@ -97,7 +97,7 @@ class VGG13_Weights(WeightsEnum):
class VGG13_BN_Weights(WeightsEnum): class VGG13_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 133053736, "num_params": 133053736,
...@@ -111,7 +111,7 @@ class VGG13_BN_Weights(WeightsEnum): ...@@ -111,7 +111,7 @@ class VGG13_BN_Weights(WeightsEnum):
class VGG16_Weights(WeightsEnum): class VGG16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth", url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 138357544, "num_params": 138357544,
...@@ -125,7 +125,10 @@ class VGG16_Weights(WeightsEnum): ...@@ -125,7 +125,10 @@ class VGG16_Weights(WeightsEnum):
IMAGENET1K_FEATURES = Weights( IMAGENET1K_FEATURES = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial( transforms=partial(
ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0) ImageClassificationEval,
crop_size=224,
mean=(0.48235, 0.45882, 0.40784),
std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
), ),
meta={ meta={
**_COMMON_META, **_COMMON_META,
...@@ -142,7 +145,7 @@ class VGG16_Weights(WeightsEnum): ...@@ -142,7 +145,7 @@ class VGG16_Weights(WeightsEnum):
class VGG16_BN_Weights(WeightsEnum): class VGG16_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 138365992, "num_params": 138365992,
...@@ -156,7 +159,7 @@ class VGG16_BN_Weights(WeightsEnum): ...@@ -156,7 +159,7 @@ class VGG16_BN_Weights(WeightsEnum):
class VGG19_Weights(WeightsEnum): class VGG19_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 143667240, "num_params": 143667240,
...@@ -170,7 +173,7 @@ class VGG19_Weights(WeightsEnum): ...@@ -170,7 +173,7 @@ class VGG19_Weights(WeightsEnum):
class VGG19_BN_Weights(WeightsEnum): class VGG19_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 143678248, "num_params": 143678248,
......
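The VGG16 IMAGENET1K_FEATURES entry is the only one above that overrides mean and std: with std = 1/255, normalization amounts to (img - mean) * 255, i.e. the float image is shifted and rescaled back to a 0-255 range, presumably matching the input convention these feature weights were trained with. A quick check of that arithmetic with dummy values:

import torch
from torchvision.transforms import functional as F

img = torch.rand(3, 224, 224)  # dummy float image in [0, 1]
mean = [0.48235, 0.45882, 0.40784]
std = [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0]

out = F.normalize(img, mean=mean, std=std)
# (img - mean) / (1 / 255) == (img - mean) * 255
print(out.min().item(), out.max().item())  # roughly -123 to 151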
...@@ -2,7 +2,7 @@ from functools import partial ...@@ -2,7 +2,7 @@ from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union from typing import Any, Callable, List, Optional, Sequence, Type, Union
from torch import nn from torch import nn
from torchvision.prototype.transforms import Kinect400Eval from torchvision.prototype.transforms import VideoClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.video.resnet import ( from ....models.video.resnet import (
...@@ -65,7 +65,7 @@ _COMMON_META = { ...@@ -65,7 +65,7 @@ _COMMON_META = {
class R3D_18_Weights(WeightsEnum): class R3D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights( KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"architecture": "R3D", "architecture": "R3D",
...@@ -80,7 +80,7 @@ class R3D_18_Weights(WeightsEnum): ...@@ -80,7 +80,7 @@ class R3D_18_Weights(WeightsEnum):
class MC3_18_Weights(WeightsEnum): class MC3_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights( KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"architecture": "MC3", "architecture": "MC3",
...@@ -95,7 +95,7 @@ class MC3_18_Weights(WeightsEnum): ...@@ -95,7 +95,7 @@ class MC3_18_Weights(WeightsEnum):
class R2Plus1D_18_Weights(WeightsEnum): class R2Plus1D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights( KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"architecture": "R(2+1)D", "architecture": "R(2+1)D",
......
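A sketch of the renamed video preset in use. Per the permute comment in _presets.py further down, it maps a (T, C, H, W) clip to (C, T, H, W); the uint8 clip here is a dummy, and the internal float conversion is an assumption:

import torch
from torchvision.prototype.transforms import VideoClassificationEval

preset = VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

clip = torch.randint(0, 256, (16, 3, 128, 171), dtype=torch.uint8)  # (T, C, H, W)
out = preset(clip)  # resize -> center crop -> normalize -> permute
assert out.shape == (3, 16, 112, 112)  # (C, T, H, W)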
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401 from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
...@@ -38,7 +38,7 @@ _COMMON_META = { ...@@ -38,7 +38,7 @@ _COMMON_META = {
class ViT_B_16_Weights(WeightsEnum): class ViT_B_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16-c867db91.pth", url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 86567656, "num_params": 86567656,
...@@ -55,7 +55,7 @@ class ViT_B_16_Weights(WeightsEnum): ...@@ -55,7 +55,7 @@ class ViT_B_16_Weights(WeightsEnum):
class ViT_B_32_Weights(WeightsEnum): class ViT_B_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth", url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 88224232, "num_params": 88224232,
...@@ -72,7 +72,7 @@ class ViT_B_32_Weights(WeightsEnum): ...@@ -72,7 +72,7 @@ class ViT_B_32_Weights(WeightsEnum):
class ViT_L_16_Weights(WeightsEnum): class ViT_L_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth", url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=242), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 304326632, "num_params": 304326632,
...@@ -89,7 +89,7 @@ class ViT_L_16_Weights(WeightsEnum): ...@@ -89,7 +89,7 @@ class ViT_L_16_Weights(WeightsEnum):
class ViT_L_32_Weights(WeightsEnum): class ViT_L_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_32-c7638314.pth", url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 306535400, "num_params": 306535400,
......
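Because every Weights entry stores its preset behind the transforms field, callers that go through the enum are unaffected by these renames. A hedged sketch, assuming the prototype WeightsEnum API of this period, where calling weights.transforms() instantiates the stored partial:

from torchvision.prototype import models

weights = models.ViT_B_16_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # builds ImageClassificationEval(crop_size=224)

# batch = preprocess(img).unsqueeze(0)
# logits = model(batch)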
...@@ -10,5 +10,11 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder ...@@ -10,5 +10,11 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval from ._presets import (
ObjectDetectionEval,
ImageClassificationEval,
SemanticSegmentationEval,
VideoClassificationEval,
OpticalFlowEval,
)
from ._type_conversion import DecodeImage, LabelToOneHot from ._type_conversion import DecodeImage, LabelToOneHot
...@@ -6,10 +6,16 @@ from torch import Tensor, nn ...@@ -6,10 +6,16 @@ from torch import Tensor, nn
from ...transforms import functional as F, InterpolationMode from ...transforms import functional as F, InterpolationMode
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"] __all__ = [
"ObjectDetectionEval",
"ImageClassificationEval",
"VideoClassificationEval",
"SemanticSegmentationEval",
"OpticalFlowEval",
]
class CocoEval(nn.Module): class ObjectDetectionEval(nn.Module):
def forward( def forward(
self, img: Tensor, target: Optional[Dict[str, Tensor]] = None self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
...@@ -18,7 +24,7 @@ class CocoEval(nn.Module): ...@@ -18,7 +24,7 @@ class CocoEval(nn.Module):
return F.convert_image_dtype(img, torch.float), target return F.convert_image_dtype(img, torch.float), target
class ImageNetEval(nn.Module): class ImageClassificationEval(nn.Module):
def __init__( def __init__(
self, self,
crop_size: int, crop_size: int,
...@@ -44,7 +50,7 @@ class ImageNetEval(nn.Module): ...@@ -44,7 +50,7 @@ class ImageNetEval(nn.Module):
return img return img
class Kinect400Eval(nn.Module): class VideoClassificationEval(nn.Module):
def __init__( def __init__(
self, self,
crop_size: Tuple[int, int], crop_size: Tuple[int, int],
...@@ -69,7 +75,7 @@ class Kinect400Eval(nn.Module): ...@@ -69,7 +75,7 @@ class Kinect400Eval(nn.Module):
return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W)
class VocEval(nn.Module): class SemanticSegmentationEval(nn.Module):
def __init__( def __init__(
self, self,
resize_size: int, resize_size: int,
...@@ -99,7 +105,7 @@ class VocEval(nn.Module): ...@@ -99,7 +105,7 @@ class VocEval(nn.Module):
return img, target return img, target
class RaftEval(nn.Module): class OpticalFlowEval(nn.Module):
def forward( def forward(
self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor] self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
......
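Taken together, the dataset-specific preset names give way to task-oriented ones, which is the point of the refactor: the same eval preset serves any dataset for a given task. A migration sketch for external code that still imports the old names (the mapping below is read straight off this diff):

# CocoEval      -> ObjectDetectionEval
# ImageNetEval  -> ImageClassificationEval
# Kinect400Eval -> VideoClassificationEval
# VocEval       -> SemanticSegmentationEval
# RaftEval      -> OpticalFlowEval
from torchvision.prototype.transforms import (
    ObjectDetectionEval,
    ImageClassificationEval,
    VideoClassificationEval,
    SemanticSegmentationEval,
    OpticalFlowEval,
)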