Unverified commit d8654bb0, authored by Vasilis Vryniotis, committed by GitHub

Refactor preset transforms (#5562)

* Refactor preset transforms

* Making presets public.
parent 2b5ab1bc
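The commit renames the prototype evaluation presets from dataset-oriented to task-oriented names, so one preset can serve any model trained for that task:

* ImageNetEval → ImageClassificationEval
* CocoEval → ObjectDetectionEval
* RaftEval → OpticalFlowEval
* VocEval → SemanticSegmentationEval
* Kinect400Eval → VideoClassificationEval

A minimal before/after sketch of the call sites touched below (assuming `prototype` is the `torchvision.prototype` namespace, as in the reference scripts):

```python
from torchvision import prototype

# Before: the preset was named after the dataset it was tuned on.
# preprocessing = prototype.transforms.ImageNetEval(crop_size=224)

# After: the preset is named after the task it serves.
preprocessing = prototype.transforms.ImageClassificationEval(crop_size=224)
```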
@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
         weights = prototype.models.get_weight(args.weights)
         preprocessing = weights.transforms()
     else:
-        preprocessing = prototype.transforms.ImageNetEval(
+        preprocessing = prototype.transforms.ImageClassificationEval(
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )
......
@@ -57,7 +57,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.CocoEval()
+        return prototype.transforms.ObjectDetectionEval()

 def get_args_parser(add_help=True):
......
@@ -137,7 +137,7 @@ def validate(model, args):
             weights = prototype.models.get_weight(args.weights)
             preprocessing = weights.transforms()
         else:
-            preprocessing = prototype.transforms.RaftEval()
+            preprocessing = prototype.transforms.OpticalFlowEval()
     else:
         preprocessing = OpticalFlowPresetEval()
......
@@ -42,7 +42,7 @@ def get_transform(train, args):
         weights = prototype.models.get_weight(args.weights)
         return weights.transforms()
     else:
-        return prototype.transforms.VocEval(resize_size=520)
+        return prototype.transforms.SemanticSegmentationEval(resize_size=520)

 def criterion(inputs, target):
......
@@ -157,7 +157,7 @@ def main(args):
             weights = prototype.models.get_weight(args.weights)
             transform_test = weights.transforms()
         else:
-            transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
+            transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))

     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
......
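All five reference scripts above share the same selection logic: if the user passes `--weights`, the preprocessing bundled with those weights wins; otherwise the script configures the task preset by hand. A condensed sketch of that pattern (the helper name and fallback sizes here are illustrative, not part of this diff):

```python
from torchvision import prototype

def build_eval_preprocessing(args, crop_size, resize_size):
    # Hypothetical helper condensing the pattern repeated in the
    # classification/detection/optical-flow/segmentation/video scripts.
    if args.weights:
        # Prefer the preset recorded alongside the pretrained weights.
        weights = prototype.models.get_weight(args.weights)
        return weights.transforms()
    # Otherwise fall back to a manually configured task preset.
    return prototype.transforms.ImageClassificationEval(
        crop_size=crop_size, resize_size=resize_size
    )
```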
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.alexnet import AlexNet
@@ -16,7 +16,7 @@ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
 class AlexNet_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "AlexNet",
......
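In the weights enums, `transforms` holds a `functools.partial` that freezes the preset's constructor arguments; calling `weights.transforms()` later instantiates the configured preset. A self-contained sketch of that mechanism, using a stand-in class rather than the real `ImageClassificationEval`:

```python
from functools import partial

class FakeClassificationEval:
    # Stand-in for ImageClassificationEval: just records its configuration.
    def __init__(self, crop_size, resize_size=256):
        self.crop_size = crop_size
        self.resize_size = resize_size

# What Weights(transforms=partial(...)) stores: a callable with frozen args.
transforms = partial(FakeClassificationEval, crop_size=224)

# What weights.transforms() effectively does: call the stored callable.
preprocessing = transforms()
assert (preprocessing.crop_size, preprocessing.resize_size) == (224, 256)
```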
 from functools import partial
 from typing import Any, List, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.convnext import ConvNeXt, CNBlockConfig
@@ -56,7 +56,7 @@ _COMMON_META = {
 class ConvNeXt_Tiny_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=236),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
         meta={
             **_COMMON_META,
             "num_params": 28589128,
@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
 class ConvNeXt_Small_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=230),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
         meta={
             **_COMMON_META,
             "num_params": 50223688,
@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
 class ConvNeXt_Base_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 88591464,
@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
 class ConvNeXt_Large_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "num_params": 197767336,
......
@@ -3,7 +3,7 @@ from functools import partial
 from typing import Any, Optional, Tuple

 import torch.nn as nn

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.densenet import DenseNet
@@ -78,7 +78,7 @@ _COMMON_META = {
 class DenseNet121_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 7978856,
@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
 class DenseNet161_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 28681000,
@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
 class DenseNet169_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 14149480,
@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
 class DenseNet201_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/densenet201-c1103571.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 20013928,
......
 from typing import Any, Optional, Union

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.faster_rcnn import (
@@ -43,7 +43,7 @@ _COMMON_META = {
 class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 41755286,
@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,
@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
 class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 19386354,
......
 from typing import Any, Optional

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.fcos import (
@@ -27,7 +27,7 @@ __all__ = [
 class FCOS_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "FCOS",
......
 from typing import Any, Optional

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.keypoint_rcnn import (
@@ -37,7 +37,7 @@ _COMMON_META = {
 class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_LEGACY = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,
@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
     )
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             **_COMMON_META,
             "num_params": 59137258,
......
 from typing import Any, Optional

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.mask_rcnn import (
@@ -27,7 +27,7 @@ __all__ = [
 class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "MaskRCNN",
......
 from typing import Any, Optional

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.retinanet import (
@@ -28,7 +28,7 @@ __all__ = [
 class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "RetinaNet",
......
 import warnings
 from typing import Any, Optional

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssd import (
@@ -25,7 +25,7 @@ __all__ = [
 class SSD300_VGG16_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSD",
......
@@ -3,7 +3,7 @@ from functools import partial
 from typing import Any, Callable, Optional

 from torch import nn

-from torchvision.prototype.transforms import CocoEval
+from torchvision.prototype.transforms import ObjectDetectionEval
 from torchvision.transforms.functional import InterpolationMode

 from ....models.detection.ssdlite import (
@@ -30,7 +30,7 @@ __all__ = [
 class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
     COCO_V1 = Weights(
         url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
-        transforms=CocoEval,
+        transforms=ObjectDetectionEval,
         meta={
             "task": "image_object_detection",
             "architecture": "SSDLite",
......
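Unlike the classification presets, the detection entries above store the class itself rather than a `partial`, because `ObjectDetectionEval` takes no required arguments. `weights.transforms()` works the same either way, since both a bare class and a `partial` are callables. A short illustration using the renamed presets from this diff (assuming the prototype build that ships them):

```python
from functools import partial
from torchvision.prototype.transforms import ImageClassificationEval, ObjectDetectionEval

# Classification: per-model geometry is frozen into a partial.
classification_transforms = partial(ImageClassificationEval, crop_size=224)

# Detection: no required arguments, so the bare class is stored.
detection_transforms = ObjectDetectionEval

# weights.transforms() just invokes whatever callable was stored.
preprocessing = detection_transforms()
```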
@@ -2,7 +2,7 @@ from functools import partial
 from typing import Any, Optional, Sequence, Union

 from torch import nn

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf
@@ -85,7 +85,9 @@ _COMMON_META_V2 = {
 class EfficientNet_B0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 5288548,
@@ -100,7 +102,9 @@ class EfficientNet_B0_Weights(WeightsEnum):
 class EfficientNet_B1_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
-        transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 7794184,
@@ -111,7 +115,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
-        transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR),
+        transforms=partial(
+            ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 7794184,
@@ -128,7 +134,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
 class EfficientNet_B2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
-        transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 9109994,
@@ -143,7 +151,9 @@ class EfficientNet_B2_Weights(WeightsEnum):
 class EfficientNet_B3_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
-        transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 12233232,
@@ -158,7 +168,9 @@ class EfficientNet_B3_Weights(WeightsEnum):
 class EfficientNet_B4_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
-        transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 19341616,
@@ -173,7 +185,9 @@ class EfficientNet_B4_Weights(WeightsEnum):
 class EfficientNet_B5_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
-        transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 30389784,
@@ -188,7 +202,9 @@ class EfficientNet_B5_Weights(WeightsEnum):
 class EfficientNet_B6_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
-        transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 43040704,
@@ -203,7 +219,9 @@ class EfficientNet_B6_Weights(WeightsEnum):
 class EfficientNet_B7_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
-        transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC),
+        transforms=partial(
+            ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
+        ),
         meta={
             **_COMMON_META_V1,
             "num_params": 66347960,
@@ -219,7 +237,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
         transforms=partial(
-            ImageNetEval,
+            ImageClassificationEval,
             crop_size=384,
             resize_size=384,
             interpolation=InterpolationMode.BILINEAR,
@@ -239,7 +257,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
         transforms=partial(
-            ImageNetEval,
+            ImageClassificationEval,
             crop_size=480,
             resize_size=480,
             interpolation=InterpolationMode.BILINEAR,
@@ -259,7 +277,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
         transforms=partial(
-            ImageNetEval,
+            ImageClassificationEval,
            crop_size=480,
             resize_size=480,
             interpolation=InterpolationMode.BICUBIC,
......
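Each EfficientNet variant bakes its own evaluation geometry and resampling filter into the preset, which is why every entry above passes explicit `crop_size`, `resize_size`, and `interpolation` arguments. A sketch reconstructing the B0 preset by hand, with values taken from the diff above (this should match what `EfficientNet_B0_Weights.IMAGENET1K_V1.transforms()` builds):

```python
from functools import partial
from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode

# B0 evaluates on 224x224 crops taken from a 256-pixel resize, bicubic resampling.
b0_transforms = partial(
    ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
)
preprocessing = b0_transforms()
```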
@@ -2,7 +2,7 @@ import warnings
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
@@ -17,7 +17,7 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weig
 class GoogLeNet_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/googlenet-1378be20.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             "task": "image_classification",
             "architecture": "GoogLeNet",
......
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
@@ -16,7 +16,7 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_
 class Inception_V3_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
-        transforms=partial(ImageNetEval, crop_size=299, resize_size=342),
+        transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
         meta={
             "task": "image_classification",
             "architecture": "InceptionV3",
......
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.mnasnet import MNASNet
@@ -38,7 +38,7 @@ _COMMON_META = {
 class MNASNet0_5_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 2218512,
@@ -57,7 +57,7 @@ class MNASNet0_75_Weights(WeightsEnum):
 class MNASNet1_0_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "num_params": 4383312,
......
 from functools import partial
 from typing import Any, Optional

-from torchvision.prototype.transforms import ImageNetEval
+from torchvision.prototype.transforms import ImageClassificationEval
 from torchvision.transforms.functional import InterpolationMode

 from ...models.mobilenetv2 import MobileNetV2
@@ -28,7 +28,7 @@ _COMMON_META = {
 class MobileNet_V2_Weights(WeightsEnum):
     IMAGENET1K_V1 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
-        transforms=partial(ImageNetEval, crop_size=224),
+        transforms=partial(ImageClassificationEval, crop_size=224),
         meta={
             **_COMMON_META,
             "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
@@ -38,7 +38,7 @@ class MobileNet_V2_Weights(WeightsEnum):
     )
     IMAGENET1K_V2 = Weights(
         url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
-        transforms=partial(ImageNetEval, crop_size=224, resize_size=232),
+        transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
         meta={
             **_COMMON_META,
             "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
......
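MobileNet_V2 above also shows why the presets travel with the weights: the V1 and V2 checkpoints expect different resize sizes (the preset default vs. an explicit 232), so hard-coding one preset would silently mis-evaluate the other. Requesting weights by name picks up the matching preset automatically. A sketch, assuming the prototype `get_weight` accepts the enum-path string that the reference scripts pass via `--weights`:

```python
from torchvision import prototype

# Each checkpoint resolves to its own preprocessing configuration.
weights = prototype.models.get_weight("MobileNet_V2_Weights.IMAGENET1K_V2")
preprocessing = weights.transforms()  # the stored ImageClassificationEval partial, called
```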