Unverified Commit b7c23589 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Replaced ToTensor with a combination of PILToTensor and ConvertImageDtype (#4481)

* Replaced ToTensor with a combination of PILToTensor and ConvertImageDtype

* Pass dtype
parent 4ef8e6bc
import torch
from torchvision.transforms import autoaugment, transforms from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
...@@ -17,7 +18,8 @@ class ClassificationPresetTrain: ...@@ -17,7 +18,8 @@ class ClassificationPresetTrain:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.extend([ trans.extend([
transforms.ToTensor(), transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), transforms.Normalize(mean=mean, std=std),
]) ])
if random_erase_prob > 0: if random_erase_prob > 0:
...@@ -36,7 +38,8 @@ class ClassificationPresetEval: ...@@ -36,7 +38,8 @@ class ClassificationPresetEval:
self.transforms = transforms.Compose([ self.transforms = transforms.Compose([
transforms.Resize(resize_size, interpolation=interpolation), transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size), transforms.CenterCrop(crop_size),
transforms.ToTensor(), transforms.PILToTensor(),
transforms.ConvertImageDtype(torch.float),
transforms.Normalize(mean=mean, std=std), transforms.Normalize(mean=mean, std=std),
]) ])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment