Unverified Commit d8654bb0 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor preset transforms (#5562)

* Refactor preset transforms

* Making presets public.
parent 2b5ab1bc
...@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args): ...@@ -163,7 +163,7 @@ def load_data(traindir, valdir, args):
weights = prototype.models.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms()
else: else:
preprocessing = prototype.transforms.ImageNetEval( preprocessing = prototype.transforms.ImageClassificationEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
) )
......
...@@ -57,7 +57,7 @@ def get_transform(train, args): ...@@ -57,7 +57,7 @@ def get_transform(train, args):
weights = prototype.models.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
return weights.transforms() return weights.transforms()
else: else:
return prototype.transforms.CocoEval() return prototype.transforms.ObjectDetectionEval()
def get_args_parser(add_help=True): def get_args_parser(add_help=True):
......
...@@ -137,7 +137,7 @@ def validate(model, args): ...@@ -137,7 +137,7 @@ def validate(model, args):
weights = prototype.models.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms()
else: else:
preprocessing = prototype.transforms.RaftEval() preprocessing = prototype.transforms.OpticalFlowEval()
else: else:
preprocessing = OpticalFlowPresetEval() preprocessing = OpticalFlowPresetEval()
......
...@@ -42,7 +42,7 @@ def get_transform(train, args): ...@@ -42,7 +42,7 @@ def get_transform(train, args):
weights = prototype.models.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
return weights.transforms() return weights.transforms()
else: else:
return prototype.transforms.VocEval(resize_size=520) return prototype.transforms.SemanticSegmentationEval(resize_size=520)
def criterion(inputs, target): def criterion(inputs, target):
......
...@@ -157,7 +157,7 @@ def main(args): ...@@ -157,7 +157,7 @@ def main(args):
weights = prototype.models.get_weight(args.weights) weights = prototype.models.get_weight(args.weights)
transform_test = weights.transforms() transform_test = weights.transforms()
else: else:
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) transform_test = prototype.transforms.VideoClassificationEval(crop_size=(112, 112), resize_size=(128, 171))
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
print(f"Loading dataset_test from {cache_path}") print(f"Loading dataset_test from {cache_path}")
......
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.alexnet import AlexNet from ...models.alexnet import AlexNet
...@@ -16,7 +16,7 @@ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"] ...@@ -16,7 +16,7 @@ __all__ = ["AlexNet", "AlexNet_Weights", "alexnet"]
class AlexNet_Weights(WeightsEnum): class AlexNet_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", url="https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
"task": "image_classification", "task": "image_classification",
"architecture": "AlexNet", "architecture": "AlexNet",
......
from functools import partial from functools import partial
from typing import Any, List, Optional from typing import Any, List, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.convnext import ConvNeXt, CNBlockConfig from ...models.convnext import ConvNeXt, CNBlockConfig
...@@ -56,7 +56,7 @@ _COMMON_META = { ...@@ -56,7 +56,7 @@ _COMMON_META = {
class ConvNeXt_Tiny_Weights(WeightsEnum): class ConvNeXt_Tiny_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth", url="https://download.pytorch.org/models/convnext_tiny-983f1562.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=236), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=236),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 28589128, "num_params": 28589128,
...@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): ...@@ -70,7 +70,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum):
class ConvNeXt_Small_Weights(WeightsEnum): class ConvNeXt_Small_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_small-0c510722.pth", url="https://download.pytorch.org/models/convnext_small-0c510722.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=230), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=230),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 50223688, "num_params": 50223688,
...@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum): ...@@ -84,7 +84,7 @@ class ConvNeXt_Small_Weights(WeightsEnum):
class ConvNeXt_Base_Weights(WeightsEnum): class ConvNeXt_Base_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_base-6075fbad.pth", url="https://download.pytorch.org/models/convnext_base-6075fbad.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 88591464, "num_params": 88591464,
...@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum): ...@@ -98,7 +98,7 @@ class ConvNeXt_Base_Weights(WeightsEnum):
class ConvNeXt_Large_Weights(WeightsEnum): class ConvNeXt_Large_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/convnext_large-ea097f82.pth", url="https://download.pytorch.org/models/convnext_large-ea097f82.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 197767336, "num_params": 197767336,
......
...@@ -3,7 +3,7 @@ from functools import partial ...@@ -3,7 +3,7 @@ from functools import partial
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple
import torch.nn as nn import torch.nn as nn
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.densenet import DenseNet from ...models.densenet import DenseNet
...@@ -78,7 +78,7 @@ _COMMON_META = { ...@@ -78,7 +78,7 @@ _COMMON_META = {
class DenseNet121_Weights(WeightsEnum): class DenseNet121_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet121-a639ec97.pth", url="https://download.pytorch.org/models/densenet121-a639ec97.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 7978856, "num_params": 7978856,
...@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum): ...@@ -92,7 +92,7 @@ class DenseNet121_Weights(WeightsEnum):
class DenseNet161_Weights(WeightsEnum): class DenseNet161_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet161-8d451a50.pth", url="https://download.pytorch.org/models/densenet161-8d451a50.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 28681000, "num_params": 28681000,
...@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum): ...@@ -106,7 +106,7 @@ class DenseNet161_Weights(WeightsEnum):
class DenseNet169_Weights(WeightsEnum): class DenseNet169_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet169-b2777c0a.pth", url="https://download.pytorch.org/models/densenet169-b2777c0a.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 14149480, "num_params": 14149480,
...@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum): ...@@ -120,7 +120,7 @@ class DenseNet169_Weights(WeightsEnum):
class DenseNet201_Weights(WeightsEnum): class DenseNet201_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/densenet201-c1103571.pth", url="https://download.pytorch.org/models/densenet201-c1103571.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 20013928, "num_params": 20013928,
......
from typing import Any, Optional, Union from typing import Any, Optional, Union
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.faster_rcnn import ( from ....models.detection.faster_rcnn import (
...@@ -43,7 +43,7 @@ _COMMON_META = { ...@@ -43,7 +43,7 @@ _COMMON_META = {
class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 41755286, "num_params": 41755286,
...@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -57,7 +57,7 @@ class FasterRCNN_ResNet50_FPN_Weights(WeightsEnum):
class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 19386354, "num_params": 19386354,
...@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum): ...@@ -71,7 +71,7 @@ class FasterRCNN_MobileNet_V3_Large_FPN_Weights(WeightsEnum):
class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum): class FasterRCNN_MobileNet_V3_Large_320_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 19386354, "num_params": 19386354,
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.fcos import ( from ....models.detection.fcos import (
...@@ -27,7 +27,7 @@ __all__ = [ ...@@ -27,7 +27,7 @@ __all__ = [
class FCOS_ResNet50_FPN_Weights(WeightsEnum): class FCOS_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
"task": "image_object_detection", "task": "image_object_detection",
"architecture": "FCOS", "architecture": "FCOS",
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.keypoint_rcnn import ( from ....models.detection.keypoint_rcnn import (
...@@ -37,7 +37,7 @@ _COMMON_META = { ...@@ -37,7 +37,7 @@ _COMMON_META = {
class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_LEGACY = Weights( COCO_LEGACY = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-9f466800.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 59137258, "num_params": 59137258,
...@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum): ...@@ -48,7 +48,7 @@ class KeypointRCNN_ResNet50_FPN_Weights(WeightsEnum):
) )
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth", url="https://download.pytorch.org/models/keypointrcnn_resnet50_fpn_coco-fc266e95.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 59137258, "num_params": 59137258,
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.mask_rcnn import ( from ....models.detection.mask_rcnn import (
...@@ -27,7 +27,7 @@ __all__ = [ ...@@ -27,7 +27,7 @@ __all__ = [
class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum): class MaskRCNN_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth", url="https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
"task": "image_object_detection", "task": "image_object_detection",
"architecture": "MaskRCNN", "architecture": "MaskRCNN",
......
from typing import Any, Optional from typing import Any, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.retinanet import ( from ....models.detection.retinanet import (
...@@ -28,7 +28,7 @@ __all__ = [ ...@@ -28,7 +28,7 @@ __all__ = [
class RetinaNet_ResNet50_FPN_Weights(WeightsEnum): class RetinaNet_ResNet50_FPN_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth", url="https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
"task": "image_object_detection", "task": "image_object_detection",
"architecture": "RetinaNet", "architecture": "RetinaNet",
......
import warnings import warnings
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssd import ( from ....models.detection.ssd import (
...@@ -25,7 +25,7 @@ __all__ = [ ...@@ -25,7 +25,7 @@ __all__ = [
class SSD300_VGG16_Weights(WeightsEnum): class SSD300_VGG16_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth", url="https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
"task": "image_object_detection", "task": "image_object_detection",
"architecture": "SSD", "architecture": "SSD",
......
...@@ -3,7 +3,7 @@ from functools import partial ...@@ -3,7 +3,7 @@ from functools import partial
from typing import Any, Callable, Optional from typing import Any, Callable, Optional
from torch import nn from torch import nn
from torchvision.prototype.transforms import CocoEval from torchvision.prototype.transforms import ObjectDetectionEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ....models.detection.ssdlite import ( from ....models.detection.ssdlite import (
...@@ -30,7 +30,7 @@ __all__ = [ ...@@ -30,7 +30,7 @@ __all__ = [
class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum): class SSDLite320_MobileNet_V3_Large_Weights(WeightsEnum):
COCO_V1 = Weights( COCO_V1 = Weights(
url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth", url="https://download.pytorch.org/models/ssdlite320_mobilenet_v3_large_coco-a79551df.pth",
transforms=CocoEval, transforms=ObjectDetectionEval,
meta={ meta={
"task": "image_object_detection", "task": "image_object_detection",
"architecture": "SSDLite", "architecture": "SSDLite",
......
...@@ -2,7 +2,7 @@ from functools import partial ...@@ -2,7 +2,7 @@ from functools import partial
from typing import Any, Optional, Sequence, Union from typing import Any, Optional, Sequence, Union
from torch import nn from torch import nn
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf from ...models.efficientnet import EfficientNet, MBConvConfig, FusedMBConvConfig, _efficientnet_conf
...@@ -85,7 +85,9 @@ _COMMON_META_V2 = { ...@@ -85,7 +85,9 @@ _COMMON_META_V2 = {
class EfficientNet_B0_Weights(WeightsEnum): class EfficientNet_B0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 5288548, "num_params": 5288548,
...@@ -100,7 +102,9 @@ class EfficientNet_B0_Weights(WeightsEnum): ...@@ -100,7 +102,9 @@ class EfficientNet_B0_Weights(WeightsEnum):
class EfficientNet_B1_Weights(WeightsEnum): class EfficientNet_B1_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 7794184, "num_params": 7794184,
...@@ -111,7 +115,9 @@ class EfficientNet_B1_Weights(WeightsEnum): ...@@ -111,7 +115,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
) )
IMAGENET1K_V2 = Weights( IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth",
transforms=partial(ImageNetEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR), transforms=partial(
ImageClassificationEval, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 7794184, "num_params": 7794184,
...@@ -128,7 +134,9 @@ class EfficientNet_B1_Weights(WeightsEnum): ...@@ -128,7 +134,9 @@ class EfficientNet_B1_Weights(WeightsEnum):
class EfficientNet_B2_Weights(WeightsEnum): class EfficientNet_B2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
transforms=partial(ImageNetEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 9109994, "num_params": 9109994,
...@@ -143,7 +151,9 @@ class EfficientNet_B2_Weights(WeightsEnum): ...@@ -143,7 +151,9 @@ class EfficientNet_B2_Weights(WeightsEnum):
class EfficientNet_B3_Weights(WeightsEnum): class EfficientNet_B3_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
transforms=partial(ImageNetEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 12233232, "num_params": 12233232,
...@@ -158,7 +168,9 @@ class EfficientNet_B3_Weights(WeightsEnum): ...@@ -158,7 +168,9 @@ class EfficientNet_B3_Weights(WeightsEnum):
class EfficientNet_B4_Weights(WeightsEnum): class EfficientNet_B4_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
transforms=partial(ImageNetEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 19341616, "num_params": 19341616,
...@@ -173,7 +185,9 @@ class EfficientNet_B4_Weights(WeightsEnum): ...@@ -173,7 +185,9 @@ class EfficientNet_B4_Weights(WeightsEnum):
class EfficientNet_B5_Weights(WeightsEnum): class EfficientNet_B5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
transforms=partial(ImageNetEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 30389784, "num_params": 30389784,
...@@ -188,7 +202,9 @@ class EfficientNet_B5_Weights(WeightsEnum): ...@@ -188,7 +202,9 @@ class EfficientNet_B5_Weights(WeightsEnum):
class EfficientNet_B6_Weights(WeightsEnum): class EfficientNet_B6_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
transforms=partial(ImageNetEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 43040704, "num_params": 43040704,
...@@ -203,7 +219,9 @@ class EfficientNet_B6_Weights(WeightsEnum): ...@@ -203,7 +219,9 @@ class EfficientNet_B6_Weights(WeightsEnum):
class EfficientNet_B7_Weights(WeightsEnum): class EfficientNet_B7_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
transforms=partial(ImageNetEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC), transforms=partial(
ImageClassificationEval, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC
),
meta={ meta={
**_COMMON_META_V1, **_COMMON_META_V1,
"num_params": 66347960, "num_params": 66347960,
...@@ -219,7 +237,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum): ...@@ -219,7 +237,7 @@ class EfficientNet_V2_S_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth",
transforms=partial( transforms=partial(
ImageNetEval, ImageClassificationEval,
crop_size=384, crop_size=384,
resize_size=384, resize_size=384,
interpolation=InterpolationMode.BILINEAR, interpolation=InterpolationMode.BILINEAR,
...@@ -239,7 +257,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum): ...@@ -239,7 +257,7 @@ class EfficientNet_V2_M_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth",
transforms=partial( transforms=partial(
ImageNetEval, ImageClassificationEval,
crop_size=480, crop_size=480,
resize_size=480, resize_size=480,
interpolation=InterpolationMode.BILINEAR, interpolation=InterpolationMode.BILINEAR,
...@@ -259,7 +277,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum): ...@@ -259,7 +277,7 @@ class EfficientNet_V2_L_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth",
transforms=partial( transforms=partial(
ImageNetEval, ImageClassificationEval,
crop_size=480, crop_size=480,
resize_size=480, resize_size=480,
interpolation=InterpolationMode.BICUBIC, interpolation=InterpolationMode.BICUBIC,
......
...@@ -2,7 +2,7 @@ import warnings ...@@ -2,7 +2,7 @@ import warnings
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs from ...models.googlenet import GoogLeNet, GoogLeNetOutputs, _GoogLeNetOutputs
...@@ -17,7 +17,7 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weig ...@@ -17,7 +17,7 @@ __all__ = ["GoogLeNet", "GoogLeNetOutputs", "_GoogLeNetOutputs", "GoogLeNet_Weig
class GoogLeNet_Weights(WeightsEnum): class GoogLeNet_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/googlenet-1378be20.pth", url="https://download.pytorch.org/models/googlenet-1378be20.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
"task": "image_classification", "task": "image_classification",
"architecture": "GoogLeNet", "architecture": "GoogLeNet",
......
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs from ...models.inception import Inception3, InceptionOutputs, _InceptionOutputs
...@@ -16,7 +16,7 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_ ...@@ -16,7 +16,7 @@ __all__ = ["Inception3", "InceptionOutputs", "_InceptionOutputs", "Inception_V3_
class Inception_V3_Weights(WeightsEnum): class Inception_V3_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", url="https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
transforms=partial(ImageNetEval, crop_size=299, resize_size=342), transforms=partial(ImageClassificationEval, crop_size=299, resize_size=342),
meta={ meta={
"task": "image_classification", "task": "image_classification",
"architecture": "InceptionV3", "architecture": "InceptionV3",
......
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mnasnet import MNASNet from ...models.mnasnet import MNASNet
...@@ -38,7 +38,7 @@ _COMMON_META = { ...@@ -38,7 +38,7 @@ _COMMON_META = {
class MNASNet0_5_Weights(WeightsEnum): class MNASNet0_5_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth", url="https://download.pytorch.org/models/mnasnet0.5_top1_67.823-3ffadce67e.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 2218512, "num_params": 2218512,
...@@ -57,7 +57,7 @@ class MNASNet0_75_Weights(WeightsEnum): ...@@ -57,7 +57,7 @@ class MNASNet0_75_Weights(WeightsEnum):
class MNASNet1_0_Weights(WeightsEnum): class MNASNet1_0_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth", url="https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"num_params": 4383312, "num_params": 4383312,
......
from functools import partial from functools import partial
from typing import Any, Optional from typing import Any, Optional
from torchvision.prototype.transforms import ImageNetEval from torchvision.prototype.transforms import ImageClassificationEval
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
from ...models.mobilenetv2 import MobileNetV2 from ...models.mobilenetv2 import MobileNetV2
...@@ -28,7 +28,7 @@ _COMMON_META = { ...@@ -28,7 +28,7 @@ _COMMON_META = {
class MobileNet_V2_Weights(WeightsEnum): class MobileNet_V2_Weights(WeightsEnum):
IMAGENET1K_V1 = Weights( IMAGENET1K_V1 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", url="https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
transforms=partial(ImageNetEval, crop_size=224), transforms=partial(ImageClassificationEval, crop_size=224),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2", "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv2",
...@@ -38,7 +38,7 @@ class MobileNet_V2_Weights(WeightsEnum): ...@@ -38,7 +38,7 @@ class MobileNet_V2_Weights(WeightsEnum):
) )
IMAGENET1K_V2 = Weights( IMAGENET1K_V2 = Weights(
url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth", url="https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth",
transforms=partial(ImageNetEval, crop_size=224, resize_size=232), transforms=partial(ImageClassificationEval, crop_size=224, resize_size=232),
meta={ meta={
**_COMMON_META, **_COMMON_META,
"recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning", "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-reg-tuning",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment