Unverified Commit 1703e4ca authored by Vasilis Vryniotis, committed by GitHub

Adding Preset Transforms in reference scripts (#3317)

* Adding presets in the classification reference scripts.

* Adding presets in the object detection reference scripts.

* Adding presets in the segmentation reference scripts.

* Adding presets in the video classification reference scripts.

* Moving flip to the end to align with the image classification signature.
parent 7621a8ed
references/classification/README.md:

@@ -124,22 +124,9 @@ Training converges at about 10 epochs.
 For post training quant, device is set to CPU. For training, the device is set to CUDA
 
 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models except inception_v3:
+For all quantized models:
 ```
 python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
     --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```
-
-For inception_v3, since it expects tensors with a size of N x 3 x 299 x 299, before running above command,
-need to change the input size of dataset_test in train.py to:
-```
-dataset_test = torchvision.datasets.ImageFolder(
-    valdir,
-    transforms.Compose([
-        transforms.Resize(342),
-        transforms.CenterCrop(299),
-        transforms.ToTensor(),
-        normalize,
-    ]))
-```
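With this change the inception_v3 exception disappears from the docs, because train.py (see the diff below) now derives the resize and crop sizes from the model name. The same README command therefore covers it directly, e.g.:

```
python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
    --device='cpu' --test-only --backend='fbgemm' --model='inception_v3'
```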
references/classification/presets.py (new file):

```
from torchvision.transforms import autoaugment, transforms


class ClassificationPresetTrain:
    def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), hflip_prob=0.5,
                 auto_augment_policy=None, random_erase_prob=0.0):
        trans = [transforms.RandomResizedCrop(crop_size)]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        if auto_augment_policy is not None:
            aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
            trans.append(autoaugment.AutoAugment(policy=aa_policy))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))

        self.transforms = transforms.Compose(trans)

    def __call__(self, img):
        return self.transforms(img)


class ClassificationPresetEval:
    def __init__(self, crop_size, resize_size=256, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img):
        return self.transforms(img)
```
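For context, a minimal sketch of these presets used standalone (not part of the diff; the policy string "imagenet" is one of the values `autoaugment.AutoAugmentPolicy` accepts):

```
from PIL import Image

import presets

# Train preset: RandomResizedCrop -> flip -> AutoAugment -> ToTensor ->
# Normalize, with optional RandomErasing applied on the tensor.
train_tf = presets.ClassificationPresetTrain(
    crop_size=224, auto_augment_policy="imagenet", random_erase_prob=0.1)

# Eval preset: Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize.
eval_tf = presets.ClassificationPresetEval(crop_size=224)

img = Image.new("RGB", (500, 375))   # stand-in for a dataset sample
x_train = train_tf(img)              # float tensor of shape [3, 224, 224]
x_eval = eval_tf(img)                # float tensor of shape [3, 224, 224]
```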
references/classification/train.py:

@@ -6,8 +6,8 @@ import torch
 import torch.utils.data
 from torch import nn
 import torchvision
-from torchvision import transforms
 
+import presets
 import utils
 
 try:
@@ -82,8 +82,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                                     std=[0.229, 0.224, 0.225])
+    resize_size, crop_size = (342, 299) if args.model == 'inception_v3' else (256, 224)
 
     print("Loading training data")
     st = time.time()
@@ -93,22 +92,10 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
-        trans = [
-            transforms.RandomResizedCrop(224),
-            transforms.RandomHorizontalFlip(),
-        ]
-        if args.auto_augment is not None:
-            aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
-            trans.append(transforms.AutoAugment(policy=aa_policy))
-        trans.extend([
-            transforms.ToTensor(),
-            normalize,
-        ])
-        if args.random_erase > 0:
-            trans.append(transforms.RandomErasing(p=args.random_erase))
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            transforms.Compose(trans))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
+                                              random_erase_prob=args.random_erase))
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
@@ -124,12 +111,7 @@ def load_data(traindir, valdir, args):
     else:
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
-            transforms.Compose([
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                normalize,
-            ]))
+            presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size))
         if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
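Assuming the existing `--auto-augment` and `--random-erase` options that `args.auto_augment` and `args.random_erase` map to, a training run exercising the new preset parameters might look like:

```
python references/classification/train.py --data-path='imagenet_full_size/' \
    --model='resnet50' --auto-augment='imagenet' --random-erase=0.1
```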
references/detection/presets.py (new file):

```
import transforms as T


class DetectionPresetTrain:
    def __init__(self, hflip_prob=0.5):
        trans = [T.ToTensor()]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))

        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class DetectionPresetEval:
    def __init__(self):
        self.transforms = T.ToTensor()

    def __call__(self, img, target):
        return self.transforms(img, target)
```
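Unlike the classification presets, these are pair transforms: the detection `transforms` module threads `(image, target)` through every step, so a horizontal flip also mirrors the boxes. A minimal sketch with a hypothetical COCO-style target (illustrative, not part of the diff):

```
import torch
from PIL import Image

import presets

train_tf = presets.DetectionPresetTrain(hflip_prob=0.5)

img = Image.new("RGB", (640, 480))
target = {
    "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0]]),  # xyxy format
    "labels": torch.tensor([1]),
}
# ToTensor converts the image; the flip (when sampled) adjusts the boxes too.
img_t, target = train_tf(img, target)
```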
references/detection/train.py:

@@ -32,8 +32,8 @@ from coco_utils import get_coco, get_coco_kp
 from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
 from engine import train_one_epoch, evaluate
 
+import presets
 import utils
-import transforms as T
 
 
 def get_dataset(name, image_set, transform, data_path):
@@ -48,11 +48,7 @@ def get_dataset(name, image_set, transform, data_path):
 
 def get_transform(train):
-    transforms = []
-    transforms.append(T.ToTensor())
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-    return T.Compose(transforms)
+    return presets.DetectionPresetTrain() if train else presets.DetectionPresetEval()
 
 
 def main(args):
references/segmentation/presets.py (new file):

```
import transforms as T


class SegmentationPresetTrain:
    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        min_size = int(0.5 * base_size)
        max_size = int(2.0 * base_size)

        trans = [T.RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans.append(T.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            T.RandomCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
        self.transforms = T.Compose(trans)

    def __call__(self, img, target):
        return self.transforms(img, target)


class SegmentationPresetEval:
    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.transforms = T.Compose([
            T.RandomResize(base_size, base_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])

    def __call__(self, img, target):
        return self.transforms(img, target)
```
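These also run on `(image, target)` pairs, keeping the segmentation mask aligned with the image through the scale jitter (a random resize in [0.5 * base_size, 2.0 * base_size]), flip, and crop. A minimal sketch (illustrative, not part of the diff):

```
from PIL import Image

import presets

train_tf = presets.SegmentationPresetTrain(base_size=520, crop_size=480)
eval_tf = presets.SegmentationPresetEval(base_size=520)

img = Image.new("RGB", (640, 480))
mask = Image.new("L", (640, 480))    # stand-in for a class-index mask

# Image and mask are resized, flipped, and cropped together; only the
# image is normalized, while the mask becomes an int64 tensor of class ids.
img_t, mask_t = train_tf(img, mask)  # [3, 480, 480] and [480, 480]
```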
references/segmentation/train.py:

@@ -8,7 +8,7 @@ from torch import nn
 import torchvision
 
 from coco_utils import get_coco
-import transforms as T
+import presets
 import utils
@@ -30,18 +30,7 @@ def get_transform(train):
     base_size = 520
     crop_size = 480
 
-    min_size = int((0.5 if train else 1.0) * base_size)
-    max_size = int((2.0 if train else 1.0) * base_size)
-    transforms = []
-    transforms.append(T.RandomResize(min_size, max_size))
-    if train:
-        transforms.append(T.RandomHorizontalFlip(0.5))
-        transforms.append(T.RandomCrop(crop_size))
-    transforms.append(T.ToTensor())
-    transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
-                                  std=[0.229, 0.224, 0.225]))
-    return T.Compose(transforms)
+    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
 
 
 def criterion(inputs, target):
references/video_classification/presets.py (new file):

```
import torch

from torchvision.transforms import transforms
from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW


class VideoClassificationPresetTrain:
    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989),
                 hflip_prob=0.5):
        trans = [
            ConvertBHWCtoBCHW(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size),
        ]
        if hflip_prob > 0:
            trans.append(transforms.RandomHorizontalFlip(hflip_prob))
        trans.extend([
            transforms.Normalize(mean=mean, std=std),
            transforms.RandomCrop(crop_size),
            ConvertBCHWtoCBHW()
        ])
        self.transforms = transforms.Compose(trans)

    def __call__(self, x):
        return self.transforms(x)


class VideoClassificationPresetEval:
    def __init__(self, resize_size, crop_size, mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)):
        self.transforms = transforms.Compose([
            ConvertBHWCtoBCHW(),
            transforms.ConvertImageDtype(torch.float32),
            transforms.Resize(resize_size),
            transforms.Normalize(mean=mean, std=std),
            transforms.CenterCrop(crop_size),
            ConvertBCHWtoCBHW()
        ])

    def __call__(self, x):
        return self.transforms(x)
```
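The video presets take a clip as a single uint8 tensor of stacked frames in `[T, H, W, C]` layout (what the video datasets yield) and return a float `[C, T, H, W]` tensor as the video models expect; the Convert* helpers at either end handle the permutations. A minimal sketch (illustrative, not part of the diff):

```
import torch

import presets

train_tf = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
eval_tf = presets.VideoClassificationPresetEval((128, 171), (112, 112))

# Fake 16-frame RGB clip in the [T, H, W, C] uint8 layout of the datasets.
clip = torch.randint(0, 256, (16, 240, 320, 3), dtype=torch.uint8)

out = train_tf(clip)      # float tensor of shape [3, 16, 112, 112]
out_eval = eval_tf(clip)  # same shape, with a deterministic center crop
```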
references/video_classification/train.py:

@@ -7,13 +7,12 @@ from torch.utils.data.dataloader import default_collate
 from torch import nn
 import torchvision
 import torchvision.datasets.video_utils
-from torchvision import transforms as T
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
+import presets
 import utils
 
 from scheduler import WarmupMultiStepLR
-from transforms import ConvertBHWCtoBCHW, ConvertBCHWtoCBHW
 
 try:
     from apex import amp
@@ -112,21 +111,11 @@ def main(args):
     print("Loading data")
     traindir = os.path.join(args.data_path, args.train_dir)
     valdir = os.path.join(args.data_path, args.val_dir)
-    normalize = T.Normalize(mean=[0.43216, 0.394666, 0.37645],
-                            std=[0.22803, 0.22145, 0.216989])
 
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    transform_train = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        T.RandomHorizontalFlip(),
-        normalize,
-        T.RandomCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_train from {}".format(cache_path))
@@ -154,14 +143,7 @@
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    transform_test = torchvision.transforms.Compose([
-        ConvertBHWCtoBCHW(),
-        T.ConvertImageDtype(torch.float32),
-        T.Resize((128, 171)),
-        normalize,
-        T.CenterCrop((112, 112)),
-        ConvertBCHWtoCBHW()
-    ])
+    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
 
     if args.cache_dataset and os.path.exists(cache_path):
         print("Loading dataset_test from {}".format(cache_path))