Unverified Commit d8654bb0 authored by Vasilis Vryniotis, committed by GitHub

Refactor preset transforms (#5562)

* Refactor preset transforms

* Making presets public.
parent 2b5ab1bc
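
This commit renames the dataset-specific evaluation presets to task-specific names: CocoEval becomes ObjectDetectionEval, ImageNetEval becomes ImageClassificationEval, Kinect400Eval becomes VideoClassificationEval, VocEval becomes SemanticSegmentationEval, and RaftEval becomes OpticalFlowEval. As the hunks below show, each weights entry stores a transforms factory built with functools.partial rather than a transform instance. A minimal sketch of that pattern, reusing names and sizes from this diff (the surrounding Weights plumbing is simplified away):

from functools import partial

from torchvision.prototype.transforms import ImageClassificationEval

# A weights entry keeps a factory, so the preset is only built on demand:
transforms_factory = partial(ImageClassificationEval, crop_size=224, resize_size=232)

preprocess = transforms_factory()  # instantiate the eval-time transform
# tensor = preprocess(img)         # resize -> center-crop -> rescale/normalize, per the preset
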
from functools import partial
from typing import Any, Optional, List
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig
@@ -51,7 +51,7 @@ _COMMON_META = {
class MobileNet_V3_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 5483032,
@@ -62,7 +62,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 5483032,
@@ -77,7 +77,7 @@ class MobileNet_V3_Large_Weights(WeightsEnum):
class MobileNet_V3_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2542856,
......
@@ -4,7 +4,7 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.instancenorm import InstanceNorm2d
from torchvision.models.optical_flow import RAFT
from torchvision.models.optical_flow.raft import _raft, BottleneckBlock, ResidualBlock
from torchvision.prototype.transforms import RaftEval
from torchvision.prototype.transforms import OpticalFlowEval
from torchvision.transforms.functional import InterpolationMode
from .._api import WeightsEnum
@@ -33,7 +33,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-things.pth)
url="https://download.pytorch.org/models/raft_large_C_T_V1-22a6c225.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -48,7 +48,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_large_C_T_V2-1bb1363a.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -63,7 +63,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_SKHT_V1 = Weights(
# Chairs + Things + Sintel fine-tuning, ported from original paper repo (raft-sintel.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V1-0b8c9e55.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -78,7 +78,7 @@ class Raft_Large_Weights(WeightsEnum):
# Chairs + Things + (Sintel + Kitti + HD1K + Things_clean)
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_V2-ff5fadd5.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -91,7 +91,7 @@ class Raft_Large_Weights(WeightsEnum):
C_T_SKHT_K_V1 = Weights(
# Chairs + Things + Sintel fine-tuning + Kitti fine-tuning, ported from the original repo (sintel-kitti.pth)
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V1-4a6a5039.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -106,7 +106,7 @@ class Raft_Large_Weights(WeightsEnum):
# Same as CT_SKHT with extra fine-tuning on Kitti
# Corresponds to the C+T+S+K+H on paper with fine-tuning on Sintel and then on Kitti
url="https://download.pytorch.org/models/raft_large_C_T_SKHT_K_V2-b5c70766.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 5257536,
@@ -122,7 +122,7 @@ class Raft_Small_Weights(WeightsEnum):
C_T_V1 = Weights(
# Chairs + Things, ported from original paper repo (raft-small.pth)
url="https://download.pytorch.org/models/raft_small_C_T_V1-ad48884c.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
@@ -136,7 +136,7 @@ class Raft_Small_Weights(WeightsEnum):
C_T_V2 = Weights(
# Chairs + Things
url="https://download.pytorch.org/models/raft_small_C_T_V2-01064c6d.pth",
transforms=RaftEval,
transforms=OpticalFlowEval,
meta={
**_COMMON_META,
"num_params": 990162,
......
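
Note that the optical-flow entries above assign the preset class itself (transforms=OpticalFlowEval) instead of a partial, since OpticalFlowEval takes no constructor arguments; either way the stored value is a callable that builds the transform. A sketch of how a caller would obtain the preset from a weights enum (the import path and attribute access below mirror the prototype API and are assumptions, not part of this diff):

from torchvision.prototype.models.optical_flow import Raft_Large_Weights  # assumed path

weights = Raft_Large_Weights.C_T_V1
preprocess = weights.transforms()  # calls the stored factory, yielding an OpticalFlowEval instance
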
@@ -2,7 +2,7 @@ import warnings
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.googlenet import (
@@ -26,7 +26,7 @@ __all__ = [
class GoogLeNet_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/googlenet_fbgemm-c00238cf.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "GoogLeNet",
......
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.inception import (
@@ -25,7 +25,7 @@ __all__ = [
class Inception_V3_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/inception_v3_google_fbgemm-71447a44.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
meta={
"task": "image_classification",
"architecture": "InceptionV3",
......
from functools import partial
from typing import Any, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv2 import (
@@ -26,7 +26,7 @@ __all__ = [
class MobileNet_V2_QuantizedWeights(WeightsEnum):
IMAGENET1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "MobileNetV2",
......
@@ -2,7 +2,7 @@ from functools import partial
from typing import Any, List, Optional, Union
import torch
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.mobilenetv3 import (
@@ -59,7 +59,7 @@ def _mobilenet_v3_model(
class MobileNet_V3_Large_QuantizedWeights(WeightsEnum):
IMAGENET1K_QNNPACK_V1 = Weights(
url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
"task": "image_classification",
"architecture": "MobileNetV3",
......
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.resnet import (
@@ -68,7 +68,7 @@ _COMMON_META = {
class ResNet18_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -85,7 +85,7 @@ class ResNet18_QuantizedWeights(WeightsEnum):
class ResNet50_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -98,7 +98,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
)
IMAGENET1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm-23753f79.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -115,7 +115,7 @@ class ResNet50_QuantizedWeights(WeightsEnum):
class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
@@ -128,7 +128,7 @@ class ResNeXt101_32X8D_QuantizedWeights(WeightsEnum):
)
IMAGENET1K_FBGEMM_V2 = Weights(
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm-ee16d00c.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
......
from functools import partial
from typing import Any, List, Optional, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.quantization.shufflenetv2 import (
@@ -67,7 +67,7 @@ _COMMON_META = {
class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 1366792,
@@ -82,7 +82,7 @@ class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
IMAGENET1K_FBGEMM_V1 = Weights(
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2278604,
......
@@ -2,7 +2,7 @@ from functools import partial
from typing import Any, Optional
from torch import nn
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.regnet import RegNet, BlockParams
@@ -77,7 +77,7 @@ def _regnet(
class RegNet_Y_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 4344144,
@@ -88,7 +88,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_400mf-e6988f5f.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 4344144,
@@ -103,7 +103,7 @@ class RegNet_Y_400MF_Weights(WeightsEnum):
class RegNet_Y_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 6432512,
@@ -114,7 +114,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_800mf-58fc7688.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 6432512,
@@ -129,7 +129,7 @@ class RegNet_Y_800MF_Weights(WeightsEnum):
class RegNet_Y_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 11202430,
@@ -140,7 +140,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_1_6gf-0d7bc02a.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 11202430,
@@ -155,7 +155,7 @@ class RegNet_Y_1_6GF_Weights(WeightsEnum):
class RegNet_Y_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 19436338,
@@ -166,7 +166,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_3_2gf-9180c971.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 19436338,
@@ -181,7 +181,7 @@ class RegNet_Y_3_2GF_Weights(WeightsEnum):
class RegNet_Y_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39381472,
@@ -192,7 +192,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_8gf-dc2b1b54.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39381472,
@@ -207,7 +207,7 @@ class RegNet_Y_8GF_Weights(WeightsEnum):
class RegNet_Y_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 83590140,
@@ -218,7 +218,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_16gf-3e4a00f9.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 83590140,
@@ -233,7 +233,7 @@ class RegNet_Y_16GF_Weights(WeightsEnum):
class RegNet_Y_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 145046770,
@@ -244,7 +244,7 @@ class RegNet_Y_32GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_y_32gf-8db6d4b5.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 145046770,
@@ -264,7 +264,7 @@ class RegNet_Y_128GF_Weights(WeightsEnum):
class RegNet_X_400MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 5495976,
@@ -275,7 +275,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_400mf-62229a5f.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 5495976,
@@ -290,7 +290,7 @@ class RegNet_X_400MF_Weights(WeightsEnum):
class RegNet_X_800MF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 7259656,
@@ -301,7 +301,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_800mf-94a99ebd.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 7259656,
@@ -316,7 +316,7 @@ class RegNet_X_800MF_Weights(WeightsEnum):
class RegNet_X_1_6GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 9190136,
@@ -327,7 +327,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_1_6gf-a12f2b72.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 9190136,
@@ -342,7 +342,7 @@ class RegNet_X_1_6GF_Weights(WeightsEnum):
class RegNet_X_3_2GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 15296552,
@@ -353,7 +353,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_3_2gf-7071aa85.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 15296552,
@@ -368,7 +368,7 @@ class RegNet_X_3_2GF_Weights(WeightsEnum):
class RegNet_X_8GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 39572648,
@@ -379,7 +379,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_8gf-2b70d774.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 39572648,
@@ -394,7 +394,7 @@ class RegNet_X_8GF_Weights(WeightsEnum):
class RegNet_X_16GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 54278536,
@@ -405,7 +405,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_16gf-ba3796d7.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 54278536,
@@ -420,7 +420,7 @@ class RegNet_X_16GF_Weights(WeightsEnum):
class RegNet_X_32GF_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 107811560,
@@ -431,7 +431,7 @@ class RegNet_X_32GF_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/regnet_x_32gf-6eb8fdc6.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"num_params": 107811560,
......
from functools import partial
from typing import Any, List, Optional, Type, Union
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.resnet import BasicBlock, Bottleneck, ResNet
@@ -63,7 +63,7 @@ _COMMON_META = {
class ResNet18_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet18-f37072fd.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -80,7 +80,7 @@ class ResNet18_Weights(WeightsEnum):
class ResNet34_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet34-b627a593.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -97,7 +97,7 @@ class ResNet34_Weights(WeightsEnum):
class ResNet50_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet50-0676ba61.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -110,7 +110,7 @@ class ResNet50_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -127,7 +127,7 @@ class ResNet50_Weights(WeightsEnum):
class ResNet101_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet101-63fe2227.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -140,7 +140,7 @@ class ResNet101_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet101-cd907fc2.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -157,7 +157,7 @@ class ResNet101_Weights(WeightsEnum):
class ResNet152_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnet152-394f9c45.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -170,7 +170,7 @@ class ResNet152_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnet152-f82ba261.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNet",
@@ -187,7 +187,7 @@ class ResNet152_Weights(WeightsEnum):
class ResNeXt50_32X4D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
@@ -200,7 +200,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
@@ -217,7 +217,7 @@ class ResNeXt50_32X4D_Weights(WeightsEnum):
class ResNeXt101_32X8D_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
@@ -230,7 +230,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "ResNeXt",
@@ -247,7 +247,7 @@ class ResNeXt101_32X8D_Weights(WeightsEnum):
class Wide_ResNet50_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
@@ -260,7 +260,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
@@ -277,7 +277,7 @@ class Wide_ResNet50_2_Weights(WeightsEnum):
class Wide_ResNet101_2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"architecture": "WideResNet",
@@ -290,7 +290,7 @@ class Wide_ResNet101_2_Weights(WeightsEnum):
)
IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={
**_COMMON_META,
"architecture": "WideResNet",
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
@@ -36,7 +36,7 @@ _COMMON_META = {
class DeepLabV3_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 42004074,
@@ -51,7 +51,7 @@ class DeepLabV3_ResNet50_Weights(WeightsEnum):
class DeepLabV3_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 60996202,
@@ -66,7 +66,7 @@ class DeepLabV3_ResNet101_Weights(WeightsEnum):
class DeepLabV3_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/deeplabv3_mobilenet_v3_large-fc3c493d.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 11029328,
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.fcn import FCN, _fcn_resnet
@@ -26,7 +26,7 @@ _COMMON_META = {
class FCN_ResNet50_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet50_coco-1167a1af.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 35322218,
@@ -41,7 +41,7 @@ class FCN_ResNet50_Weights(WeightsEnum):
class FCN_ResNet101_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/fcn_resnet101_coco-7ecb50ca.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
**_COMMON_META,
"num_params": 54314346,
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import VocEval
from torchvision.prototype.transforms import SemanticSegmentationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
@@ -17,7 +17,7 @@ __all__ = ["LRASPP", "LRASPP_MobileNet_V3_Large_Weights", "lraspp_mobilenet_v3_l
class LRASPP_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_WITH_VOC_LABELS_V1 = Weights(
url="https://download.pytorch.org/models/lraspp_mobilenet_v3_large-d234d4ea.pth",
transforms=partial(VocEval, resize_size=520),
transforms=partial(SemanticSegmentationEval, resize_size=520),
meta={
"task": "image_semantic_segmentation",
"architecture": "LRASPP",
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.shufflenetv2 import ShuffleNetV2
@@ -55,7 +55,7 @@ _COMMON_META = {
class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 1366792,
@@ -69,7 +69,7 @@ class ShuffleNet_V2_X0_5_Weights(WeightsEnum):
class ShuffleNet_V2_X1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 2278604,
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.squeezenet import SqueezeNet
@@ -27,7 +27,7 @@ _COMMON_META = {
class SqueezeNet1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"min_size": (21, 21),
@@ -42,7 +42,7 @@ class SqueezeNet1_0_Weights(WeightsEnum):
class SqueezeNet1_1_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"min_size": (17, 17),
......
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.vgg import VGG, make_layers, cfgs
@@ -55,7 +55,7 @@ _COMMON_META = {
class VGG11_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11-8a719046.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132863336,
@@ -69,7 +69,7 @@ class VGG11_Weights(WeightsEnum):
class VGG11_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 132868840,
@@ -83,7 +83,7 @@ class VGG11_BN_Weights(WeightsEnum):
class VGG13_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13-19584684.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133047848,
@@ -97,7 +97,7 @@ class VGG13_Weights(WeightsEnum):
class VGG13_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 133053736,
@@ -111,7 +111,7 @@ class VGG13_BN_Weights(WeightsEnum):
class VGG16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16-397923af.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138357544,
@@ -125,7 +125,10 @@ class VGG16_Weights(WeightsEnum):
IMAGENET1K_FEATURES = Weights(
url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth",
transforms=partial(
ImageNetEval, crop_size=224, mean=(0.48235, 0.45882, 0.40784), std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0)
ImageClassificationEval,
crop_size=224,
mean=(0.48235, 0.45882, 0.40784),
std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0),
),
meta={
**_COMMON_META,
@@ -142,7 +145,7 @@ class VGG16_Weights(WeightsEnum):
class VGG16_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 138365992,
@@ -156,7 +159,7 @@ class VGG16_BN_Weights(WeightsEnum):
class VGG19_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143667240,
@@ -170,7 +173,7 @@ class VGG19_Weights(WeightsEnum):
class VGG19_BN_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 143678248,
......
@@ -2,7 +2,7 @@ from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type, Union
from torch import nn
from torchvision.prototype.transforms import Kinect400Eval
from torchvision.prototype.transforms import VideoClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ....models.video.resnet import (
@@ -65,7 +65,7 @@ _COMMON_META = {
class R3D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R3D",
@@ -80,7 +80,7 @@ class R3D_18_Weights(WeightsEnum):
class MC3_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "MC3",
@@ -95,7 +95,7 @@ class MC3_18_Weights(WeightsEnum):
class R2Plus1D_18_Weights(WeightsEnum):
KINETICS400_V1 = Weights(
url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth",
transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)),
transforms=partial(VideoClassificationEval, crop_size=(112, 112), resize_size=(128, 171)),
meta={
**_COMMON_META,
"architecture": "R(2+1)D",
......
@@ -5,7 +5,7 @@
from functools import partial
from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode
from ...models.vision_transformer import VisionTransformer, interpolate_embeddings # noqa: F401
@@ -38,7 +38,7 @@ _COMMON_META = {
class ViT_B_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_16-c867db91.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 86567656,
@@ -55,7 +55,7 @@ class ViT_B_16_Weights(WeightsEnum):
class ViT_B_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_b_32-d86f8d99.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 88224232,
@@ -72,7 +72,7 @@ class ViT_B_32_Weights(WeightsEnum):
class ViT_L_16_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_16-852ce7e3.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=242),
transforms=partial(ImageClassificationEval, crop_size=224, resize_size=242),
meta={
**_COMMON_META,
"num_params": 304326632,
@@ -89,7 +89,7 @@ class ViT_L_16_Weights(WeightsEnum):
class ViT_L_32_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/vit_l_32-c7638314.pth",
transforms=partial(ImageNetEval, crop_size=224),
transforms=partial(ImageClassificationEval, crop_size=224),
meta={
**_COMMON_META,
"num_params": 306535400,
......
@@ -10,5 +10,11 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
from ._presets import (
ObjectDetectionEval,
ImageClassificationEval,
SemanticSegmentationEval,
VideoClassificationEval,
OpticalFlowEval,
)
from ._type_conversion import DecodeImage, LabelToOneHot
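
With the re-export above, all five renamed presets become importable directly from torchvision.prototype.transforms, which is what the "Making presets public" commit refers to. A quick sanity check of the new public surface (assuming a build where the prototype namespace is available):

from torchvision.prototype import transforms

for name in ("ObjectDetectionEval", "ImageClassificationEval",
             "SemanticSegmentationEval", "VideoClassificationEval",
             "OpticalFlowEval"):
    assert hasattr(transforms, name)  # each task-named preset is now exposed
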
@@ -6,10 +6,16 @@ from torch import Tensor, nn
from ...transforms import functional as F, InterpolationMode
__all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval", "RaftEval"]
__all__ = [
"ObjectDetectionEval",
"ImageClassificationEval",
"VideoClassificationEval",
"SemanticSegmentationEval",
"OpticalFlowEval",
]
class CocoEval(nn.Module):
class ObjectDetectionEval(nn.Module):
def forward(
self, img: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
@@ -18,7 +24,7 @@ class CocoEval(nn.Module):
return F.convert_image_dtype(img, torch.float), target
class ImageNetEval(nn.Module):
class ImageClassificationEval(nn.Module):
def __init__(
self,
crop_size: int,
@@ -44,7 +50,7 @@ class ImageNetEval(nn.Module):
return img
class Kinect400Eval(nn.Module):
class VideoClassificationEval(nn.Module):
def __init__(
self,
crop_size: Tuple[int, int],
@@ -69,7 +75,7 @@ class Kinect400Eval(nn.Module):
return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W)
class VocEval(nn.Module):
class SemanticSegmentationEval(nn.Module):
def __init__(
self,
resize_size: int,
@@ -99,7 +105,7 @@ class VocEval(nn.Module):
return img, target
class RaftEval(nn.Module):
class OpticalFlowEval(nn.Module):
def forward(
self, img1: Tensor, img2: Tensor, flow: Optional[Tensor], valid_flow_mask: Optional[Tensor]
) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
......
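
The call signatures carry over unchanged from the old classes, as the forward definitions above show; only the class names differ. A hedged usage sketch with stand-in tensors (the shapes are illustrative and not taken from this PR):

import torch

from torchvision.prototype.transforms import OpticalFlowEval, VideoClassificationEval

# OpticalFlowEval.forward(img1, img2, flow, valid_flow_mask); flow and mask may be None
img1, img2, flow, valid = OpticalFlowEval()(
    torch.rand(3, 368, 496), torch.rand(3, 368, 496), None, None
)

# VideoClassificationEval takes a (T, C, H, W) clip and returns (C, T, H, W),
# per the permute at the end of its forward above
clip = VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))(
    torch.rand(16, 3, 128, 171)
)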