Commit 0ab7d05c (unverified), authored by Nicolas Hug, committed by GitHub
Browse files

Allow classification references to use the tensor backend (#7629)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent a6f63879
...@@ -16,8 +16,16 @@ class ClassificationPresetTrain: ...@@ -16,8 +16,16 @@ class ClassificationPresetTrain:
ra_magnitude=9, ra_magnitude=9,
augmix_severity=3, augmix_severity=3,
random_erase_prob=0.0, random_erase_prob=0.0,
backend="pil",
): ):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] trans = []
backend = backend.lower()
if backend == "tensor":
trans.append(transforms.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0: if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob)) trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None: if auto_augment_policy is not None:
...@@ -30,9 +38,12 @@ class ClassificationPresetTrain: ...@@ -30,9 +38,12 @@ class ClassificationPresetTrain:
else: else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
if backend == "pil":
trans.append(transforms.PILToTensor())
trans.extend( trans.extend(
[ [
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float), transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), transforms.Normalize(mean=mean, std=std),
] ]
...@@ -55,17 +66,30 @@ class ClassificationPresetEval: ...@@ -55,17 +66,30 @@ class ClassificationPresetEval:
mean=(0.485, 0.456, 0.406), mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225), std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR, interpolation=InterpolationMode.BILINEAR,
backend="pil",
): ):
trans = []
self.transforms = transforms.Compose( backend = backend.lower()
[ if backend == "tensor":
transforms.Resize(resize_size, interpolation=interpolation), trans.append(transforms.PILToTensor())
transforms.CenterCrop(crop_size), else:
transforms.PILToTensor(), raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), trans += [
] transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
) transforms.CenterCrop(crop_size),
]
if backend == "pil":
trans.append(transforms.PILToTensor())
trans += [
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
self.transforms = transforms.Compose(trans)
def __call__(self, img): def __call__(self, img):
return self.transforms(img) return self.transforms(img)
...@@ -7,6 +7,7 @@ import presets ...@@ -7,6 +7,7 @@ import presets
import torch import torch
import torch.utils.data import torch.utils.data
import torchvision import torchvision
import torchvision.transforms
import transforms import transforms
import utils import utils
from sampler import RASampler from sampler import RASampler
...@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args): ...@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args):
random_erase_prob=random_erase_prob, random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude, ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity, augmix_severity=augmix_severity,
backend=args.backend,
), ),
) )
if args.cache_dataset: if args.cache_dataset:
...@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args): ...@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args):
else: else:
if args.weights and args.test_only: if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights) weights = torchvision.models.get_weight(args.weights)
preprocessing = weights.transforms() preprocessing = weights.transforms(antialias=True)
if args.backend == "tensor":
preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
else: else:
preprocessing = presets.ClassificationPresetEval( preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation crop_size=val_crop_size,
resize_size=val_resize_size,
interpolation=interpolation,
backend=args.backend,
) )
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
...@@ -507,6 +515,7 @@ def get_args_parser(add_help=True): ...@@ -507,6 +515,7 @@ def get_args_parser(add_help=True):
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)" "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
) )
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
return parser return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment