Unverified commit 08c9938f authored by Nicolas Hug, committed by GitHub

Add --use-v2 support to classification references (#7724)

parent 23b0938f
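
The change threads a new use_v2 switch through the classification presets and exposes it on the training script's argument parser. Assuming the usual classification-reference workflow (the model name and launcher options below are placeholders, not part of this commit), opting into the v2 transforms would look roughly like:

    torchrun --nproc_per_node=8 train.py --model resnet50 --use-v2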
references/classification/presets.py

 import torch
-from torchvision.transforms import autoaugment, transforms
 from torchvision.transforms.functional import InterpolationMode
 
 
+def get_module(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2
+    else:
+        import torchvision.transforms
+
+        return torchvision.transforms
+
+
 class ClassificationPresetTrain:
     def __init__(
         self,
@@ -17,41 +28,44 @@ class ClassificationPresetTrain:
         augmix_severity=3,
         random_erase_prob=0.0,
         backend="pil",
+        use_v2=False,
     ):
-        trans = []
+        module = get_module(use_v2)
+
+        transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
+            transforms.append(module.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
-                trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         if backend == "pil":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
 
-        trans.extend(
+        transforms.extend(
             [
-                transforms.ConvertImageDtype(torch.float),
-                transforms.Normalize(mean=mean, std=std),
+                module.ConvertImageDtype(torch.float),
+                module.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            trans.append(transforms.RandomErasing(p=random_erase_prob))
+            transforms.append(module.RandomErasing(p=random_erase_prob))
 
-        self.transforms = transforms.Compose(trans)
+        self.transforms = module.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
@@ -67,28 +81,30 @@ class ClassificationPresetEval:
         std=(0.229, 0.224, 0.225),
         interpolation=InterpolationMode.BILINEAR,
         backend="pil",
+        use_v2=False,
     ):
-        trans = []
+        module = get_module(use_v2)
+        transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        trans += [
-            transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
-            transforms.CenterCrop(crop_size),
+        transforms += [
+            module.Resize(resize_size, interpolation=interpolation, antialias=True),
+            module.CenterCrop(crop_size),
         ]
         if backend == "pil":
-            trans.append(transforms.PILToTensor())
+            transforms.append(module.PILToTensor())
 
-        trans += [
-            transforms.ConvertImageDtype(torch.float),
-            transforms.Normalize(mean=mean, std=std),
+        transforms += [
+            module.ConvertImageDtype(torch.float),
+            module.Normalize(mean=mean, std=std),
         ]
-        self.transforms = transforms.Compose(trans)
+        self.transforms = module.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
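
To illustrate what the flag changes at the preset level, here is a small usage sketch (not part of this commit; it assumes the working directory is the classification references folder so that presets is importable, and a torchvision build that ships torchvision.transforms.v2):

    from PIL import Image

    from presets import ClassificationPresetEval

    # With use_v2=True the pipeline is assembled from torchvision.transforms.v2
    # instead of the v1 namespace, but the call signature stays the same.
    preset = ClassificationPresetEval(crop_size=224, resize_size=232, use_v2=True)

    img = Image.new("RGB", (500, 375))  # dummy PIL input
    out = preset(img)
    print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32

The same applies to ClassificationPresetTrain, which only differs in the augmentations it composes.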
references/classification/train.py

@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
                 ra_magnitude=ra_magnitude,
                 augmix_severity=augmix_severity,
                 backend=args.backend,
+                use_v2=args.use_v2,
             ),
         )
         if args.cache_dataset:
@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
                resize_size=val_resize_size,
                interpolation=interpolation,
                backend=args.backend,
+               use_v2=args.use_v2,
            )
 
        dataset_test = torchvision.datasets.ImageFolder(
@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
     parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
+    parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
     return parser
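
A note on the protected import in get_module above: at the time of this change the v2 transforms were still in beta, and importing torchvision.transforms.v2 emitted a one-time beta-status warning, which is why the import only happens when use_v2 is requested. A caller who opts in can also silence the warning explicitly; a minimal sketch, assuming the beta-era helper torchvision.disable_beta_transforms_warning() is present in the installed version:

    import torchvision

    # Assumed beta-era API: suppresses the warning that is otherwise emitted
    # the first time the v2 transforms namespace is imported.
    torchvision.disable_beta_transforms_warning()

    import torchvision.transforms.v2 as T

    print(T.RandomResizedCrop)  # v2 transform classes are now importable warning-free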