Commit b280c318, authored by Vasilis Vryniotis, committed by GitHub

Additional SOTA ingredients on Classification Recipe (#4493)

* Update EMA every X iters.

* Adding AdamW optimizer.

* Adjusting EMA decay scheme.

* Support custom weight decay for Normalization layers.

* Fix indentation bug.

* Change EMA adjustment.

* Quality of life changes to facilitate testing.

* ufmt format

* Fixing imports.

* Adding FixRes improvement.

* Support EMA in store_model_weights.

* Adding interpolation values.

* Change train_crop_size.

* Add interpolation option.

* Removing hardcoded interpolation and sizes from the scripts.

* Fixing linter.

* Incorporating feedback from code review.
parent f8468e72
...@@ -31,6 +31,17 @@ Here `$MODEL` is one of `alexnet`, `vgg11`, `vgg13`, `vgg16` or `vgg19`. Note
that `vgg11_bn`, `vgg13_bn`, `vgg16_bn`, and `vgg19_bn` include batch
normalization and thus are trained with the default parameters.
### Inception V3
The weights of the Inception V3 model are ported from the original paper rather than trained from scratch.
Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model use the following command:
```
torchrun --nproc_per_node=8 train.py --model inception_v3 \
      --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
```
### ResNext-50 32x4d
``` ```
torchrun --nproc_per_node=8 train.py\ torchrun --nproc_per_node=8 train.py\
...@@ -79,6 +90,25 @@ The weights of the B0-B4 variants are ported from Ross Wightman's [timm repo](ht
The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTorch repo](https://github.com/lukemelas/EfficientNet-PyTorch/blob/1039e009545d9329ea026c9f7541341439712b96/efficientnet_pytorch/utils.py#L562-L564).
All models were trained using bicubic interpolation, and each has its own crop and resize sizes. To validate the models, use the following commands:
```
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
--val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
--val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
--val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
--val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
--val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
```
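The eight commands above differ only in the variant name and its resize/crop sizes, which this commit moves out of `load_data` and into command-line flags. As a rough sketch of that mapping, the snippet below regenerates the same commands from a size table. The names `EFFICIENTNET_SIZES` and `validation_command` are hypothetical helpers for illustration and are not part of the repository.

```python
# Hypothetical helper: (val_resize_size, crop_size) per EfficientNet variant,
# copied from the README commands above.
EFFICIENTNET_SIZES = {
    "b0": (256, 224),
    "b1": (256, 240),
    "b2": (288, 288),
    "b3": (320, 300),
    "b4": (384, 380),
    "b5": (456, 456),
    "b6": (528, 528),
    "b7": (600, 600),
}


def validation_command(variant: str, nproc_per_node: int = 8) -> str:
    """Build the torchrun validation command for one EfficientNet variant."""
    resize_size, crop_size = EFFICIENTNET_SIZES[variant]
    return (
        f"torchrun --nproc_per_node={nproc_per_node} train.py "
        f"--model efficientnet_{variant} --interpolation bicubic "
        f"--val-resize-size {resize_size} --val-crop-size {crop_size} "
        f"--train-crop-size {crop_size} --test-only --pretrained"
    )


if __name__ == "__main__":
    for variant in EFFICIENTNET_SIZES:
        print(validation_command(variant))
```

Note that in every command the train crop size equals the validation crop size, so the sketch reuses one value for both flags.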
### RegNet
...@@ -181,3 +211,8 @@ For post training quant, device is set to CPU. For training, the device is set t
```
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
```
For inception_v3 you need to pass the following extra parameters:
```
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299
```
...@@ -9,21 +9,22 @@ class ClassificationPresetTrain: ...@@ -9,21 +9,22 @@ class ClassificationPresetTrain:
crop_size, crop_size,
mean=(0.485, 0.456, 0.406), mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225), std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
hflip_prob=0.5, hflip_prob=0.5,
auto_augment_policy=None, auto_augment_policy=None,
random_erase_prob=0.0, random_erase_prob=0.0,
): ):
trans = [transforms.RandomResizedCrop(crop_size)] trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0: if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob)) trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None: if auto_augment_policy is not None:
if auto_augment_policy == "ra": if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment()) trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide": elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide()) trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
else: else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy)) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
trans.extend( trans.extend(
[ [
transforms.PILToTensor(), transforms.PILToTensor(),
......
...@@ -14,22 +14,20 @@ from torch.utils.data.dataloader import default_collate ...@@ -14,22 +14,20 @@ from torch.utils.data.dataloader import default_collate
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
def train_one_epoch( def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}")) metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}")) metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
header = "Epoch: [{}]".format(epoch) header = "Epoch: [{}]".format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header): for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time() start_time = time.time()
image, target = image.to(device), target.to(device) image, target = image.to(device), target.to(device)
output = model(image) output = model(image)
optimizer.zero_grad() optimizer.zero_grad()
if amp: if args.amp:
with torch.cuda.amp.autocast(): with torch.cuda.amp.autocast():
loss = criterion(output, target) loss = criterion(output, target)
scaler.scale(loss).backward() scaler.scale(loss).backward()
...@@ -40,6 +38,12 @@ def train_one_epoch( ...@@ -40,6 +38,12 @@ def train_one_epoch(
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if model_ema and i % args.model_ema_steps == 0:
model_ema.update_parameters(model)
if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
model_ema.n_averaged.fill_(0)
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
batch_size = image.shape[0] batch_size = image.shape[0]
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
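Reviewer note on the EMA change in this hunk: the update now runs only every `args.model_ema_steps` iterations, and during LR warmup the `n_averaged` counter is reset so that each update copies the current weights instead of averaging them. The scalar simulation below is a minimal sketch of that behaviour. `ScalarEMA` and `run_epoch` are hypothetical stand-ins, and the sketch assumes the averaged model copies the weights verbatim whenever `n_averaged` is zero, which is what the in-code comment about the warmup period implies.

```python
class ScalarEMA:
    """Scalar stand-in for the averaged model (hypothetical, illustration only)."""

    def __init__(self, decay: float) -> None:
        self.decay = decay
        self.value = 0.0
        self.n_averaged = 0

    def update_parameters(self, weight: float) -> None:
        if self.n_averaged == 0:
            # First update, or first update after a reset: copy the weight.
            self.value = weight
        else:
            # Regular exponential moving average step.
            self.value = self.decay * self.value + (1.0 - self.decay) * weight
        self.n_averaged += 1


def run_epoch(ema, weights, epoch, ema_steps, lr_warmup_epochs):
    """Mimic the update schedule of train_one_epoch on a list of scalar weights."""
    for i, weight in enumerate(weights):
        if i % ema_steps == 0:
            ema.update_parameters(weight)
            if epoch < lr_warmup_epochs:
                # Reset the counter so the next update copies again.
                ema.n_averaged = 0


ema = ScalarEMA(decay=0.5)

# Warmup epoch: updates at i=0 and i=2 both copy, so the EMA tracks the model.
run_epoch(ema, [1.0, 2.0, 3.0, 4.0], epoch=0, ema_steps=2, lr_warmup_epochs=1)
warmup_value = ema.value

# After warmup: i=0 copies 5.0 (counter was reset), i=2 averages 5.0 with 7.0.
run_epoch(ema, [5.0, 6.0, 7.0, 8.0], epoch=1, ema_steps=2, lr_warmup_epochs=1)
post_warmup_value = ema.value
```

With these toy inputs the EMA simply follows the weights during warmup and only starts averaging once `epoch >= lr_warmup_epochs`.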
...@@ -47,9 +51,6 @@ def train_one_epoch( ...@@ -47,9 +51,6 @@ def train_one_epoch(
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time)) metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
if model_ema:
model_ema.update_parameters(model)
def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""): def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=""):
model.eval() model.eval()
...@@ -106,24 +107,8 @@ def _get_cache_path(filepath): ...@@ -106,24 +107,8 @@ def _get_cache_path(filepath):
def load_data(traindir, valdir, args): def load_data(traindir, valdir, args):
# Data loading code # Data loading code
print("Loading data") print("Loading data")
resize_size, crop_size = 256, 224 val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
interpolation = InterpolationMode.BILINEAR interpolation = InterpolationMode(args.interpolation)
if args.model == "inception_v3":
resize_size, crop_size = 342, 299
elif args.model.startswith("efficientnet_"):
sizes = {
"b0": (256, 224),
"b1": (256, 240),
"b2": (288, 288),
"b3": (320, 300),
"b4": (384, 380),
"b5": (456, 456),
"b6": (528, 528),
"b7": (600, 600),
}
e_type = args.model.replace("efficientnet_", "")
resize_size, crop_size = sizes[e_type]
interpolation = InterpolationMode.BICUBIC
print("Loading training data") print("Loading training data")
st = time.time() st = time.time()
...@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args): ...@@ -138,7 +123,10 @@ def load_data(traindir, valdir, args):
dataset = torchvision.datasets.ImageFolder( dataset = torchvision.datasets.ImageFolder(
traindir, traindir,
presets.ClassificationPresetTrain( presets.ClassificationPresetTrain(
crop_size=crop_size, auto_augment_policy=auto_augment_policy, random_erase_prob=random_erase_prob crop_size=train_crop_size,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
), ),
) )
if args.cache_dataset: if args.cache_dataset:
...@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args): ...@@ -156,7 +144,9 @@ def load_data(traindir, valdir, args):
else: else:
dataset_test = torchvision.datasets.ImageFolder( dataset_test = torchvision.datasets.ImageFolder(
valdir, valdir,
presets.ClassificationPresetEval(crop_size=crop_size, resize_size=resize_size, interpolation=interpolation), presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
),
) )
if args.cache_dataset: if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path)) print("Saving dataset_test to {}".format(cache_path))
...@@ -224,10 +214,17 @@ def main(args): ...@@ -224,10 +214,17 @@ def main(args):
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing) criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
if args.norm_weight_decay is None:
parameters = model.parameters()
else:
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [args.norm_weight_decay, args.weight_decay]
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
opt_name = args.opt.lower() opt_name = args.opt.lower()
if opt_name.startswith("sgd"): if opt_name.startswith("sgd"):
optimizer = torch.optim.SGD( optimizer = torch.optim.SGD(
model.parameters(), parameters,
lr=args.lr, lr=args.lr,
momentum=args.momentum, momentum=args.momentum,
weight_decay=args.weight_decay, weight_decay=args.weight_decay,
...@@ -235,15 +232,12 @@ def main(args): ...@@ -235,15 +232,12 @@ def main(args):
) )
elif opt_name == "rmsprop": elif opt_name == "rmsprop":
optimizer = torch.optim.RMSprop( optimizer = torch.optim.RMSprop(
model.parameters(), parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=0.0316, alpha=0.9
lr=args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
eps=0.0316,
alpha=0.9,
) )
elif opt_name == "adamw":
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
else: else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
scaler = torch.cuda.amp.GradScaler() if args.amp else None scaler = torch.cuda.amp.GradScaler() if args.amp else None
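Reviewer note on the `--norm-weight-decay` path in this hunk: the list comprehension pairs each parameter list with its own weight decay and silently drops any group that is empty, presumably so the optimizer never receives a group without parameters. The sketch below reproduces just that pairing logic with plain strings standing in for tensors. `make_optimizer_groups` is a hypothetical name used for illustration only.

```python
from typing import Any, Dict, List


def make_optimizer_groups(
    norm_params: List[Any],
    other_params: List[Any],
    norm_weight_decay: float,
    weight_decay: float,
) -> List[Dict[str, Any]]:
    """Pair each parameter list with its weight decay, skipping empty lists."""
    param_groups = (norm_params, other_params)
    wd_groups = [norm_weight_decay, weight_decay]
    return [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]


# Strings stand in for tensors; the names are made up for the example.
groups = make_optimizer_groups(
    norm_params=["bn1.weight", "bn1.bias"],
    other_params=["conv1.weight", "fc.weight", "fc.bias"],
    norm_weight_decay=0.0,
    weight_decay=1e-4,
)

# A model without normalization layers yields a single group.
no_norm_groups = make_optimizer_groups([], ["fc.weight"], 0.0, 1e-4)
```

The optimizers in this hunk are still constructed with `weight_decay=args.weight_decay`. In PyTorch, a value set inside a parameter group takes precedence over the constructor default, so the per-group values are the ones that apply.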
...@@ -288,13 +282,23 @@ def main(args): ...@@ -288,13 +282,23 @@ def main(args):
model_ema = None model_ema = None
if args.model_ema: if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay) # Decay adjustment that aims to keep the decay independent from other hyper-parameters originally proposed at:
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
#
# total_ema_updates = (Dataset_size / n_GPUs) * epochs / (batch_size_per_gpu * EMA_steps)
# We consider constant = Dataset_size for a given dataset/setup and omit it. Thus:
# adjust = 1 / total_ema_updates ~= n_GPUs * batch_size_per_gpu * EMA_steps / epochs
adjust = args.world_size * args.batch_size * args.model_ema_steps / args.epochs
alpha = 1.0 - args.model_ema_decay
alpha = min(1.0, alpha * adjust)
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=1.0 - alpha)
if args.resume: if args.resume:
checkpoint = torch.load(args.resume, map_location="cpu") checkpoint = torch.load(args.resume, map_location="cpu")
model_without_ddp.load_state_dict(checkpoint["model"]) model_without_ddp.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"]) if not args.test_only:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) optimizer.load_state_dict(checkpoint["optimizer"])
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
args.start_epoch = checkpoint["epoch"] + 1 args.start_epoch = checkpoint["epoch"] + 1
if model_ema: if model_ema:
model_ema.load_state_dict(checkpoint["model_ema"]) model_ema.load_state_dict(checkpoint["model_ema"])
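Reviewer note on the decay adjustment in this hunk: the effective decay passed to `ExponentialMovingAverage` is no longer `args.model_ema_decay` itself but a value rescaled by the number of GPUs, the per-GPU batch size, the EMA update interval and the epoch count. The worked example below repeats the same arithmetic as a standalone function. `adjusted_ema_decay` is a hypothetical name, and the 8-GPU, batch-size-128, 600-epoch setup is an assumed configuration chosen only to make the numbers concrete.

```python
def adjusted_ema_decay(
    world_size: int,
    batch_size: int,
    ema_steps: int,
    epochs: int,
    ema_decay: float,
) -> float:
    """Repeat the decay adjustment from main() as a standalone function."""
    adjust = world_size * batch_size * ema_steps / epochs
    alpha = 1.0 - ema_decay
    # Clamp so the effective decay never drops below zero.
    alpha = min(1.0, alpha * adjust)
    return 1.0 - alpha


# Assumed setup: 8 GPUs, 128 images per GPU, EMA update every 32 steps, 600 epochs.
decay = adjusted_ema_decay(world_size=8, batch_size=128, ema_steps=32, epochs=600, ema_decay=0.99998)
print(f"effective decay: {decay:.6f}")  # about 0.9989

# With very few epochs the adjustment saturates and the EMA just copies the model.
saturated = adjusted_ema_decay(world_size=64, batch_size=1024, ema_steps=32, epochs=1, ema_decay=0.9)
```

This suggests why the default `--model-ema-decay` moves from 0.9 to 0.99998 in the same commit: the configured value is now a base rate that gets scaled up, not the decay applied at each update.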
...@@ -303,8 +307,10 @@ def main(args): ...@@ -303,8 +307,10 @@ def main(args):
# We disable the cudnn benchmarking because it can noticeably affect the accuracy # We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
if model_ema:
evaluate(model, criterion, data_loader_test, device=device) evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
else:
evaluate(model, criterion, data_loader_test, device=device)
return return
print("Start training") print("Start training")
...@@ -312,9 +318,7 @@ def main(args): ...@@ -312,9 +318,7 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs): for epoch in range(args.start_epoch, args.epochs):
if args.distributed: if args.distributed:
train_sampler.set_epoch(epoch) train_sampler.set_epoch(epoch)
train_one_epoch( train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
)
lr_scheduler.step() lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device) evaluate(model, criterion, data_loader_test, device=device)
if model_ema: if model_ema:
...@@ -362,6 +366,12 @@ def get_args_parser(add_help=True): ...@@ -362,6 +366,12 @@ def get_args_parser(add_help=True):
help="weight decay (default: 1e-4)", help="weight decay (default: 1e-4)",
dest="weight_decay", dest="weight_decay",
) )
parser.add_argument(
"--norm-weight-decay",
default=None,
type=float,
help="weight decay for Normalization layers (default: None, same value as --wd)",
)
parser.add_argument( parser.add_argument(
"--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing" "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
) )
...@@ -415,15 +425,33 @@ def get_args_parser(add_help=True): ...@@ -415,15 +425,33 @@ def get_args_parser(add_help=True):
parser.add_argument( parser.add_argument(
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
) )
parser.add_argument(
"--model-ema-steps",
type=int,
default=32,
help="the number of iterations that controls how often to update the EMA model (default: 32)",
)
parser.add_argument( parser.add_argument(
"--model-ema-decay", "--model-ema-decay",
type=float, type=float,
default=0.9, default=0.99998,
help="decay factor for Exponential Moving Average of model parameters(default: 0.9)", help="decay factor for Exponential Moving Average of model parameters (default: 0.99998)",
) )
parser.add_argument( parser.add_argument(
"--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only." "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
) )
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
return parser return parser
......
...@@ -236,6 +236,19 @@ def get_args_parser(add_help=True): ...@@ -236,6 +236,19 @@ def get_args_parser(add_help=True):
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training") parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
parser.add_argument(
"--interpolation", default="bilinear", type=str, help="the interpolation method (default: bilinear)"
)
parser.add_argument(
"--val-resize-size", default=256, type=int, help="the resize size used for validation (default: 256)"
)
parser.add_argument(
"--val-crop-size", default=224, type=int, help="the central crop size used for validation (default: 224)"
)
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
return parser return parser
......
...@@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T ...@@ -380,6 +380,9 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
# Load the weights to the model to validate that everything works # Load the weights to the model to validate that everything works
# and remove unnecessary weights (such as auxiliaries, etc) # and remove unnecessary weights (such as auxiliaries, etc)
if checkpoint_key == "model_ema":
del checkpoint[checkpoint_key]["n_averaged"]
torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(checkpoint[checkpoint_key], "module.")
model.load_state_dict(checkpoint[checkpoint_key], strict=strict) model.load_state_dict(checkpoint[checkpoint_key], strict=strict)
tmp_path = os.path.join(output_dir, str(model.__hash__())) tmp_path = os.path.join(output_dir, str(model.__hash__()))
......
...@@ -9,10 +9,10 @@ import pytest ...@@ -9,10 +9,10 @@ import pytest
import torch import torch
from common_utils import needs_cuda, cpu_and_gpu, assert_equal from common_utils import needs_cuda, cpu_and_gpu, assert_equal
from PIL import Image from PIL import Image
from torch import Tensor from torch import nn, Tensor
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair from torch.nn.modules.utils import _pair
from torchvision import ops from torchvision import models, ops
class RoIOpTester(ABC): class RoIOpTester(ABC):
...@@ -1176,5 +1176,15 @@ class TestStochasticDepth: ...@@ -1176,5 +1176,15 @@ class TestStochasticDepth:
assert p_value > 0.0001 assert p_value > 0.0001
class TestUtils:
@pytest.mark.parametrize("norm_layer", [None, nn.BatchNorm2d, nn.LayerNorm])
def test_split_normalization_params(self, norm_layer):
model = models.mobilenet_v3_large(norm_layer=norm_layer)
params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
assert len(params[0]) == 92
assert len(params[1]) == 82
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])
from typing import List, Union from typing import List, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import nn, Tensor
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor: def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
...@@ -36,3 +36,28 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]): ...@@ -36,3 +36,28 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
else: else:
assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]" assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]"
return return
def split_normalization_params(
model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
# Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
if not norm_classes:
norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
for t in norm_classes:
if not issubclass(t, nn.Module):
raise ValueError(f"Class {t} is not a subclass of nn.Module.")
classes = tuple(norm_classes)
norm_params = []
other_params = []
for module in model.modules():
if next(module.children(), None):
other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
elif isinstance(module, classes):
norm_params.extend(p for p in module.parameters() if p.requires_grad)
else:
other_params.extend(p for p in module.parameters() if p.requires_grad)
return norm_params, other_params