"server/vscode:/vscode.git/clone" did not exist on "87dc034b590723a7ebf354df576a13690d9664cc"
Unverified Commit 0ab7d05c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow classification references to use the tensor backend (#7629)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent a6f63879
......@@ -16,8 +16,16 @@ class ClassificationPresetTrain:
ra_magnitude=9,
augmix_severity=3,
random_erase_prob=0.0,
backend="pil",
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
trans = []
backend = backend.lower()
if backend == "tensor":
trans.append(transforms.PILToTensor())
elif backend != "pil":
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
......@@ -30,9 +38,12 @@ class ClassificationPresetTrain:
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
if backend == "pil":
trans.append(transforms.PILToTensor())
trans.extend(
[
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
......@@ -55,17 +66,30 @@ class ClassificationPresetEval:
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
backend="pil",
):
trans = []
self.transforms = transforms.Compose(
[
transforms.Resize(resize_size, interpolation=interpolation),
backend = backend.lower()
if backend == "tensor":
trans.append(transforms.PILToTensor())
else:
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
trans += [
transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
transforms.CenterCrop(crop_size),
transforms.PILToTensor(),
]
if backend == "pil":
trans.append(transforms.PILToTensor())
trans += [
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
]
)
self.transforms = transforms.Compose(trans)
def __call__(self, img):
return self.transforms(img)
......@@ -7,6 +7,7 @@ import presets
import torch
import torch.utils.data
import torchvision
import torchvision.transforms
import transforms
import utils
from sampler import RASampler
......@@ -143,6 +144,7 @@ def load_data(traindir, valdir, args):
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
backend=args.backend,
),
)
if args.cache_dataset:
......@@ -160,10 +162,16 @@ def load_data(traindir, valdir, args):
else:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
preprocessing = weights.transforms()
preprocessing = weights.transforms(antialias=True)
if args.backend == "tensor":
preprocessing = torchvision.transforms.Compose([torchvision.transforms.PILToTensor(), preprocessing])
else:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
crop_size=val_crop_size,
resize_size=val_resize_size,
interpolation=interpolation,
backend=args.backend,
)
dataset_test = torchvision.datasets.ImageFolder(
......@@ -507,6 +515,7 @@ def get_args_parser(add_help=True):
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
return parser
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment