Unverified Commit b7c23589 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replaced ToTensor with a combination of PILToTensor and ConvertImageDtype (#4481)

* Replaced ToTensor with a combination of PILToTensor and ConvertImageDtype

* Pass dtype
parent 4ef8e6bc
import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode
......@@ -17,7 +18,8 @@ class ClassificationPresetTrain:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
])
if random_erase_prob > 0:
......@@ -36,7 +38,8 @@ class ClassificationPresetEval:
self.transforms = transforms.Compose([
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std),
])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment