Unverified Commit 08c9938f authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add --use-v2 support to classification references (#7724)

parent 23b0938f
import torch import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
def get_module(use_v2):
# We need a protected import to avoid the V2 warning in case just V1 is used
if use_v2:
import torchvision.transforms.v2
return torchvision.transforms.v2
else:
import torchvision.transforms
return torchvision.transforms
class ClassificationPresetTrain: class ClassificationPresetTrain:
def __init__( def __init__(
self, self,
...@@ -17,41 +28,44 @@ class ClassificationPresetTrain: ...@@ -17,41 +28,44 @@ class ClassificationPresetTrain:
augmix_severity=3, augmix_severity=3,
random_erase_prob=0.0, random_erase_prob=0.0,
backend="pil", backend="pil",
use_v2=False,
): ):
trans = [] module = get_module(use_v2)
transforms = []
backend = backend.lower() backend = backend.lower()
if backend == "tensor": if backend == "tensor":
trans.append(transforms.PILToTensor()) transforms.append(module.PILToTensor())
elif backend != "pil": elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True)) transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0: if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob)) transforms.append(module.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None: if auto_augment_policy is not None:
if auto_augment_policy == "ra": if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude)) transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
elif auto_augment_policy == "ta_wide": elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix": elif auto_augment_policy == "augmix":
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity)) transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
else: else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
if backend == "pil": if backend == "pil":
trans.append(transforms.PILToTensor()) transforms.append(module.PILToTensor())
trans.extend( transforms.extend(
[ [
transforms.ConvertImageDtype(torch.float), module.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), module.Normalize(mean=mean, std=std),
] ]
) )
if random_erase_prob > 0: if random_erase_prob > 0:
trans.append(transforms.RandomErasing(p=random_erase_prob)) transforms.append(module.RandomErasing(p=random_erase_prob))
self.transforms = transforms.Compose(trans) self.transforms = module.Compose(transforms)
def __call__(self, img): def __call__(self, img):
return self.transforms(img) return self.transforms(img)
...@@ -67,28 +81,30 @@ class ClassificationPresetEval: ...@@ -67,28 +81,30 @@ class ClassificationPresetEval:
std=(0.229, 0.224, 0.225), std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR, interpolation=InterpolationMode.BILINEAR,
backend="pil", backend="pil",
use_v2=False,
): ):
trans = [] module = get_module(use_v2)
transforms = []
backend = backend.lower() backend = backend.lower()
if backend == "tensor": if backend == "tensor":
trans.append(transforms.PILToTensor()) transforms.append(module.PILToTensor())
elif backend != "pil": elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}") raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
trans += [ transforms += [
transforms.Resize(resize_size, interpolation=interpolation, antialias=True), module.Resize(resize_size, interpolation=interpolation, antialias=True),
transforms.CenterCrop(crop_size), module.CenterCrop(crop_size),
] ]
if backend == "pil": if backend == "pil":
trans.append(transforms.PILToTensor()) transforms.append(module.PILToTensor())
trans += [ transforms += [
transforms.ConvertImageDtype(torch.float), module.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), module.Normalize(mean=mean, std=std),
] ]
self.transforms = transforms.Compose(trans) self.transforms = module.Compose(transforms)
def __call__(self, img): def __call__(self, img):
return self.transforms(img) return self.transforms(img)
...@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args): ...@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
ra_magnitude=ra_magnitude, ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity, augmix_severity=augmix_severity,
backend=args.backend, backend=args.backend,
use_v2=args.use_v2,
), ),
) )
if args.cache_dataset: if args.cache_dataset:
...@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args): ...@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
resize_size=val_resize_size, resize_size=val_resize_size,
interpolation=interpolation, interpolation=interpolation,
backend=args.backend, backend=args.backend,
use_v2=args.use_v2,
) )
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
...@@ -516,6 +518,7 @@ def get_args_parser(add_help=True): ...@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
) )
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive") parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
return parser return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment