Unverified commit b280c318 authored by Vasilis Vryniotis, committed by GitHub

Additional SOTA ingredients on Classification Recipe (#4493)

* Update EMA every X iters.

* Adding AdamW optimizer.

* Adjusting EMA decay scheme.

* Support custom weight decay for Normalization layers.

* Fix indentation bug.

* Change EMA adjustment.

* Quality-of-life changes to facilitate testing.

* ufmt format

* Fixing imports.

* Adding FixRes improvement.

* Support EMA in store_model_weights.

* Adding interpolation values.

* Change train_crop_size.

* Add interpolation option.

* Removing hardcoded interpolation and sizes from the scripts.

* Fixing linter.

* Incorporating feedback from code review.
parent f8468e72
......@@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters.
### Inception V3
The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
Since the model expects tensors of size N x 3 x 299 x 299, use the following command to validate it:
```
torchrun --nproc_per_node=8 train.py --model inception_v3 \
    --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
```
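As a quick local sanity check of the expected input size (a sketch, not part of the recipe), the model can be probed directly:

```
import torch
from torchvision import models

model = models.inception_v3(pretrained=True)
model.eval()  # eval mode disables the auxiliary classifier output

with torch.no_grad():
    out = model(torch.rand(1, 3, 299, 299))
print(out.shape)  # torch.Size([1, 1000])
```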
### ResNext-50 32x4d
```
torchrun --nproc_per_node=8 train.py \
......@@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht
The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
All models were trained using bicubic interpolation, and each has custom resize and crop sizes. To validate the models, use the following commands:
```
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic \
    --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic \
    --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic \
    --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic \
    --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic \
    --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic \
    --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic \
    --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic \
    --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
```
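The per-variant sizes in the commands above used to live in a lookup table inside `load_data`; below is an illustrative sketch of building the matching evaluation preset from such a table (the `EFFICIENTNET_SIZES` name is hypothetical; `presets` refers to `references/classification/presets.py`):

```
from torchvision.transforms.functional import InterpolationMode
import presets  # references/classification/presets.py

# (val_resize_size, val_crop_size) per variant, matching the commands above.
EFFICIENTNET_SIZES = {
    "b0": (256, 224), "b1": (256, 240), "b2": (288, 288), "b3": (320, 300),
    "b4": (384, 380), "b5": (456, 456), "b6": (528, 528), "b7": (600, 600),
}

resize_size, crop_size = EFFICIENTNET_SIZES["b4"]
eval_preset = presets.ClassificationPresetEval(
    crop_size=crop_size, resize_size=resize_size, interpolation=InterpolationMode.BICUBIC
)
```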
### RegNet
......@@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t
```
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
```
For `inception_v3`, you need to pass the following extra parameters:
```
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299
```
......@@ -9,21 +9,22 @@ class ClassificationPresetTrain:
crop_size,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5,
auto_augment_policy=None,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size)]
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
trans.extend(
[
transforms.PILToTensor(),
......
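With `interpolation` now threaded through to `RandomResizedCrop` and all three AutoAugment variants in the preset above, a bicubic training preset can be built like this (usage sketch; the crop size and augmentation settings are illustrative):

```
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
import presets  # references/classification/presets.py

train_preset = presets.ClassificationPresetTrain(
    crop_size=176,                            # e.g. --train-crop-size 176
    interpolation=InterpolationMode.BICUBIC,  # forwarded to crop + augment
    auto_augment_policy="ta_wide",            # the TrivialAugmentWide branch above
    random_erase_prob=0.1,
)
sample = train_preset(Image.new("RGB", (500, 400)))  # PIL image -> augmented tensor
```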
......@@ -14,22 +14,20 @@ from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode
def train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
header = "Epoch: [{}]".format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
output = model(image)
optimizer.zero_grad()
if amp:
if args.amp:
with torch.cuda.amp.autocast():
loss = criterion(output, target)
scaler.scale(loss).backward()
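For reference, the `scaler` calls above follow the standard `torch.cuda.amp` pattern; the remaining steps (step, update) happen just below this hunk. A self-contained sketch of the full pattern with a toy model (note the trainer computes the forward pass before entering autocast, while the sketch uses the canonical ordering):

```
import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

image = torch.rand(4, 10, device="cuda")
target = torch.randint(0, 2, (4,), device="cuda")

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    output = model(image)
    loss = criterion(output, target)
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)         # unscales gradients, then runs optimizer.step()
scaler.update()                # adjusts the scale factor for the next iteration
```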
......@@ -40,6 +38,12 @@ def train_one_epoch(
loss.backward()
optimizer.step()
if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
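The EMA bookkeeping above only touches the averaged copy every `--model-ema-steps` iterations, and zeroing `n_averaged` during warmup makes `update_parameters` copy (rather than average) the live weights. A minimal sketch of the same mechanics, assuming `utils.ExponentialMovingAverage` behaves like `torch.optim.swa_utils.AveragedModel` with a decay-weighted `avg_fn`:

```
import torch
from torch.optim.swa_utils import AveragedModel

decay = 0.99998  # --model-ema-decay
ema_steps = 32   # --model-ema-steps

def ema_avg(avg_param, param, num_averaged):
    # Exponential moving average instead of the default arithmetic mean.
    return decay * avg_param + (1.0 - decay) * param

model = torch.nn.Linear(8, 2)  # toy stand-in for the classification model
model_ema = AveragedModel(model, avg_fn=ema_avg)

for i in range(1000):  # stand-in for the batch loop
    # ... forward / backward / optimizer.step() ...
    if i % ema_steps == 0:
        model_ema.update_parameters(model)
        # During LR warmup one would also reset the counter so the EMA keeps
        # copying the live weights: model_ema.n_averaged.fill_(0)
```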
......@@ -47,9 +51,6 @@ def train_one_epoch(
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
if model_ema:
model_ema.update_parameters(model)
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval()
......@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
resize_size, crop_size = 256, 224
interpolation = InterpolationMode.BILINEAR
if args.model == "inception_v3":
resize_size, crop_size = 342, 299
elif args.model.startswith("efficientnet_"):
sizes = {
"b0": (256, 224),
"b1": (256, 240),
"b2": (288, 288),
"b3": (320, 300),
"b4": (384, 380),
"b5": (456, 456),
"b6": (528, 528),
"b7": (600, 600),
}
e_type = args.model.replace("efficientnet_", "")
resize_size, crop_size = sizes[e_type]
interpolation = InterpolationMode.BICUBIC
val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
interpolation = InterpolationMode(args.interpolation)
print("Loading training data")
st = time.time()
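`InterpolationMode` is an enum whose members carry the same strings the new `--interpolation` flag accepts, so the lookup above is a plain enum-by-value call:

```
from torchvision.transforms.functional import InterpolationMode

# Enum lookup by value: the CLI string maps directly onto a member.
assert InterpolationMode("bilinear") is InterpolationMode.BILINEAR
assert InterpolationMode("bicubic") is InterpolationMode.BICUBIC
```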
......@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob
crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
),
)
if args.cache_dataset:
......@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
else:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation),
presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
)
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
......@@ -224,10 +214,17 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
if args.norm_weight_decay is None:
parameters = model.parameters()
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
opt_name = args.opt.lower()
if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD(
model.parameters(),
parameters,
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
......@@ -235,15 +232,12 @@ def main(args):
)
elif opt_name == "rmsprop":
optimizer = torch.optim.RMSprop(
model.parameters(),
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9,
parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
)
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None
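All three optimizers accept either a flat iterable of tensors or a list of per-group dicts, which is what lets the `--norm-weight-decay` branch above feed the same `parameters` object to SGD, RMSprop, or AdamW. A self-contained sketch with a toy model and hypothetical decay values:

```
import torch
import torchvision
from torchvision import models

model = models.mobilenet_v3_large()
norm_params, other_params = torchvision.ops._utils.split_normalization_params(model)
parameters = [
    {"params": norm_params, "weight_decay": 0.0},    # e.g. --norm-weight-decay 0.0
    {"params": other_params, "weight_decay": 1e-4},  # e.g. --wd 1e-4
]

# Any of the supported optimizers consumes the same param groups;
# per-group weight_decay overrides the constructor default.
optimizer = torch.optim.AdamW(parameters, lr=1e-3, weight_decay=1e-4)
```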
......@@ -288,13 +282,23 @@ def main(args):
model_ema = None
if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)
# Decay adjustment that aims to keep the decay independent of other hyper-parameters, originally proposed at:
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
#
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
# We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
if not args.test_only:
optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1
if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"])
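Plugging hypothetical values into the decay adjustment above shows how a near-1.0 decay gets scaled down to something sensible for infrequent EMA updates (all numbers illustrative):

```
world_size, batch_size, ema_steps, epochs = 8, 128, 32, 600  # hypothetical run
model_ema_decay = 0.99998

adjust = world_size * batch_size * ema_steps / epochs  # ~54.61
alpha = min(1.0, (1.0 - model_ema_decay) * adjust)     # ~0.001092
effective_decay = 1.0 - alpha                          # ~0.998908
print(f"{effective_decay:.6f}")
```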
......@@ -303,8 +307,10 @@ def main(args):
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
else:
evaluate(model, criterion, data_loader_test, device=device)
return
print("Start training")
......@@ -312,9 +318,7 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(
model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
......@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)",
dest="weight_decay",
)
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument(
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
)
......@@ -415,15 +425,33 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
)
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument(
"--model-ema-decay",
type=float,
default=0.9,
help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
default=0.99998,
help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
)
parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
return parser
......
......@@ -236,6 +236,19 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
return parser
......
......@@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
# Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc)
if checkpoint_key == "model_ema":
del checkpoint[checkpoint_key]["n_averaged"]
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__()))
......
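`consume_prefix_in_state_dict_if_present` (a `torch.nn.modules.utils` helper) strips DDP-style `module.` prefixes in place; together with dropping the `n_averaged` counter, it turns an EMA checkpoint back into a plain model state dict. An illustration with a toy state dict:

```
import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# EMA checkpoints wrap the model, so keys carry a "module." prefix plus an
# extra "n_averaged" counter buffer that a plain model does not expect.
state_dict = {
    "n_averaged": torch.tensor(100),
    "module.weight": torch.rand(2, 2),
    "module.bias": torch.rand(2),
}
del state_dict["n_averaged"]
consume_prefix_in_state_dict_if_present(state_dict, "module.")
print(list(state_dict))  # ['weight', 'bias']
```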
......@@ -9,10 +9,10 @@ import pytest
import torch
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image
from torch import Tensor
from torch import nn, Tensor
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import ops
from torchvision import models, ops
class RoIOpTester(ABC):
......@@ -1176,5 +1176,15 @@ class TestStochasticDepth:
assert p_value > 0.0001
class TestUtils:
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
def test_split_normalization_params(self, norm_layer):
model = models.mobilenet_v3_large(norm_layer=norm_layer)
params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
assert len(params[0]) == 92
assert len(params[1]) == 82
if __name__ == "__main__":
pytest.main([__file__])
from typing import List, Union
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch import nn, Tensor
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
......@@ -36,3 +36,28 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
else:
assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]"
return
def split_normalization_params(
model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
for t in norm_classes:
if not issubclass(t, nn.Module):
raise ValueError(f"Class {t} is not a subclass of nn.Module.")
classes = tuple(norm_classes)
norm_params = []
other_params = []
for module in model.modules():
if next(module.children(), None):
other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
elif isinstance(module, classes):
norm_params.extend(p for p in module.parameters() if p.requires_grad)
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params
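A quick sanity check of the split on a real model, mirroring the new unit test above:

```
from torchvision import models, ops

model = models.mobilenet_v3_large()
norm_params, other_params = ops._utils.split_normalization_params(model)
print(len(norm_params), len(other_params))  # 92 normalization tensors, 82 others
```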