Commit a1c29028 authored by zhangqha

update uni-fold
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import torch
from unicore import optim
from unicore import utils
from .dynamic_loss_scaler import DynamicLossScaler
def check_param_device(params):
if len(params) <= 0:
return True
device = params[0].device
for i in range(1, len(params)):
assert device == params[i].device
def pad_numel(numel, multiplier=2):
return (numel + multiplier - 1) // multiplier * multiplier
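# Illustrative note (not part of the original file): pad_numel rounds a
# parameter's element count up to the next multiple of ``multiplier`` so that
# each segment in the flattened buffer starts at an aligned offset, e.g.
#   pad_numel(5)                -> 6
#   pad_numel(8)                -> 8
#   pad_numel(10, multiplier=4) -> 12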
class _FP16OptimizerMixin(object):
def __init__(self, args, **kwargs):
# forward __init__ call to the next class in the MRO (method resolution order)
super().__init__(args, **kwargs)
self._multiply_factor = 1.0
self.bf16_sr = getattr(args, "bf16_sr", False)
@classmethod
def build_fp32_params(cls, args, params):
# create FP32 copy of parameters and grads
total_param_size = sum([p.data.numel() for p in params])
fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
fp32_params[offset : offset + numel].copy_(p.data.view(-1))
offset += numel
fp32_params = torch.nn.Parameter(fp32_params)
fp32_params.grad = fp32_params.data.new(total_param_size)
return fp32_params
@classmethod
def flatten_fp16_parameters(cls, args, params):
dtype_grouped_params = {}
for p in params:
if p.dtype not in dtype_grouped_params:
dtype_grouped_params[p.dtype] = []
dtype_grouped_params[p.dtype].append(p)
flatten_params = {}
for dtype in dtype_grouped_params:
cur_params = dtype_grouped_params[dtype]
total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params)
flatten_params[dtype] = (
cur_params[0].new(0).type(dtype).new(total_param_size)
)
offset = 0
for p in cur_params:
numel = p.data.numel()
flatten_params[dtype][offset : offset + numel].copy_(p.data.view(-1))
p.data = (
flatten_params[dtype].data[offset : offset + numel].view(*p.shape)
)
offset += pad_numel(numel)
flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype])
flatten_params[dtype].grad = flatten_params[dtype].data.new(
total_param_size
)
offset = 0
for p in cur_params:
numel = p.data.numel()
p.grad = (
flatten_params[dtype].grad[offset : offset + numel].view(*p.shape)
)
offset += pad_numel(numel)
torch.cuda.empty_cache()
return list(flatten_params.values())
def state_dict(self):
"""Return the optimizer's state dict."""
state_dict = self.fp32_optimizer.state_dict()
if self.scaler is not None:
state_dict["loss_scale"] = self.scaler.loss_scale
return state_dict
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
if "loss_scale" in state_dict and self.scaler is not None:
self.scaler.loss_scale = state_dict["loss_scale"]
self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves.
Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this
function additionally dynamically scales the loss to avoid gradient
underflow.
"""
if self.scaler is not None:
loss = self.scaler.scale(loss)
loss.backward()
self._needs_sync = True
def _sync_fp16_grads_to_fp32(self):
with torch.no_grad():
if self._needs_sync:
offset = 0
for p in self.fp16_params:
numel = p.numel()
self.fp32_params.grad.data[offset : offset + numel].copy_(
p.grad.data.view(-1)
)
offset += pad_numel(numel)
self._needs_sync = False
def _add_fp16_grads_to_fp32(self, mul=0.0):
with torch.no_grad():
offset = 0
for p in self.fp16_params:
numel = p.numel()
self.fp32_params.grad.data[
offset : offset + numel
] += mul * p.grad.data.float().view(-1)
p.grad.zero_()
offset += pad_numel(numel)
self._needs_sync = False
def _sync_fp32_params_to_fp16(self):
# copy FP32 params back into FP16 model
offset = 0
for p in self.fp16_params:
numel = p.numel()
u = self.fp32_params.data[offset : offset + numel].view_as(p.data)
if self.bf16_sr and p.dtype == torch.bfloat16:
utils.fp32_to_bf16_sr(u, p)
else:
p.data.copy_(u)
offset += pad_numel(numel)
def _unscale_grads(self):
self._sync_fp16_grads_to_fp32()
if (
# Skip the multiplication if it's a no-op (i.e., if _multiply_factor
# is 1.0), while avoiding the device-to-host transfer that comparing
# a GPU tensor to 1.0 would require. Since _multiply_factor starts as
# a Python float, we assume that once it becomes a tensor it is
# probably no longer 1.0 and do the multiplication; otherwise we can
# safely check the value without a D2H transfer.
torch.is_tensor(self._multiply_factor)
or self._multiply_factor != 1.0
):
self.fp32_optimizer.multiply_grads(self._multiply_factor)
self._multiply_factor = 1.0
def multiply_grads(self, c):
"""Multiplies grads by a constant ``c``."""
if self._needs_sync:
self._multiply_factor *= c
else:
# gradients already synced to fp32 parameters, update it directly
self.fp32_optimizer.multiply_grads(c)
def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm."""
if max_norm <= 0.0:
return 0.0
grad_norm = self._multiply_factor * utils.clip_grad_norm_(
self.fp16_params, 0, aggregate_norm_fn
)
# grad_norm = 1.0
if grad_norm > max_norm > 0.0:
clip_coef = max_norm / (grad_norm + 1e-6)
else:
clip_coef = 1.0
self._add_fp16_grads_to_fp32(mul=clip_coef)
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm and updates dynamic loss scaler."""
self._sync_fp16_grads_to_fp32()
grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm(
0,
aggregate_norm_fn=aggregate_norm_fn,
)
if self.scaler is not None:
if grad_norm > max_norm > 0.0:
self._multiply_factor *= max_norm / grad_norm
self.scaler.check_overflow(grad_norm)
elif max_norm > 0.0:
clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1)
self._multiply_factor *= clip_coef
return grad_norm
def step(self, closure=None, groups=None):
"""Performs a single optimization step."""
self._sync_fp16_grads_to_fp32()
if getattr(self, "supports_step_with_scale", False):
self.fp32_optimizer.step(
closure, scale=(1.0 / self._multiply_factor), groups=groups
)
else:
self._unscale_grads()
self.fp32_optimizer.step(closure, groups=groups)
if self.scaler is not None:
self.scaler.update()
self._sync_fp32_params_to_fp16()
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for p in self.fp16_params:
p.grad.zero_()
if torch.is_tensor(self.fp32_params):
self.fp32_params.grad.zero_()
elif isinstance(self.fp32_params, dict):
for fp32_params in self.fp32_params.values():
fp32_params.grad.zero_()
else:
raise RuntimeError("self.fp32_params must be a tensor or dict")
self._needs_sync = False
if self.scaler is not None:
self._multiply_factor = 1.0 / float(self.scaler.loss_scale)
else:
self._multiply_factor = 1.0
class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer):
"""
Wrap an *optimizer* to support FP16 (mixed precision) training.
"""
def __init__(self, args, params, fp32_optimizer, fp32_params, **kwargs):
super().__init__(args)
self.fp16_params = params
self.fp32_optimizer = fp32_optimizer
self.fp32_params = fp32_params
self.allreduce_fp32_grad = getattr(args, "allreduce_fp32_grad", False)
if getattr(args, "fp16_scale_window", None) is None:
if len(args.update_freq) > 1:
raise ValueError(
"--fp16-scale-window must be given explicitly when using a "
"custom --update-freq schedule"
)
data_parallel_size = int(args.distributed_world_size)
scale_window = int(2**14 / data_parallel_size / args.update_freq[0])
else:
scale_window = args.fp16_scale_window
if not getattr(args, "bf16", False):
self.scaler = DynamicLossScaler(
init_scale=args.fp16_init_scale,
scale_window=scale_window,
tolerance=args.fp16_scale_tolerance,
threshold=args.threshold_loss_scale,
min_loss_scale=args.min_loss_scale,
)
else:
# disable loss scaling for bfloat16
self.scaler = None
@classmethod
def build_optimizer(cls, args, params, **kwargs):
"""
Args:
args : unicore args
params (iterable): iterable of parameters to optimize
"""
flatten = not getattr(args, "fp16_no_flatten_grads", False)
assert flatten
check_param_device(params)
params = cls.flatten_fp16_parameters(args, params)
fp32_params = cls.build_fp32_params(args, params)
fp32_optimizer = optim.build_optimizer(args, [fp32_params])
return cls(args, params, fp32_optimizer, fp32_params, **kwargs)
@property
def optimizer(self):
return self.fp32_optimizer.optimizer
@optimizer.setter
def optimizer(self, optimizer):
self.fp32_optimizer.optimizer = optimizer
@property
def lr_scheduler(self):
return getattr(self.fp32_optimizer, "lr_scheduler", None)
@property
def optimizer_config(self):
return self.fp32_optimizer.optimizer_config
def get_lr(self):
return self.fp32_optimizer.get_lr()
def set_lr(self, lr):
self.fp32_optimizer.set_lr(lr)
def all_reduce_grads(self, module):
if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"):
self._sync_fp16_grads_to_fp32()
with torch.no_grad():
params = [self.fp32_params]
module.all_reduce_params(params)
else:
self.fp32_optimizer.all_reduce_grads(module)
@property
def supports_flat_params(self):
return self.fp32_optimizer.supports_flat_params
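# Illustrative sketch (not part of the original file): a minimal CPU-only
# recreation of the flatten-then-mirror idea behind flatten_fp16_parameters()
# and build_fp32_params() above -- half-precision params become views into one
# flat buffer, and a flat fp32 master copy of that buffer is what the wrapped
# optimizer actually updates. The helper name `demo_flatten` is hypothetical.
if __name__ == "__main__":
    import torch

    def demo_flatten(params, multiplier=2):
        def pad(n):
            return (n + multiplier - 1) // multiplier * multiplier

        total = sum(pad(p.numel()) for p in params)
        flat = params[0].new_zeros(total)
        offset = 0
        for p in params:
            n = p.numel()
            flat[offset : offset + n].copy_(p.view(-1))
            p.data = flat[offset : offset + n].view(*p.shape)  # share storage
            offset += pad(n)
        master = torch.nn.Parameter(flat.float())  # flat fp32 master copy
        return flat, master

    ps = [torch.randn(3, 5, dtype=torch.half), torch.randn(7, dtype=torch.half)]
    flat, master = demo_flatten(ps)
    # the first param is now a view into the flat fp16 buffer
    assert ps[0].data_ptr() == flat.data_ptr()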
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
def get_fused_adam_class():
try:
global unicore_fused_adam
import importlib
unicore_fused_adam = importlib.import_module("unicore_fused_adam")
return FusedAdam
except ImportError:
pass
return None
class FusedAdam(torch.optim.Optimizer):
"""
Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
It has been proposed in `Adam: A Method for Stochastic Optimization`_.
Compared to the original version in Apex, the unicore version casts grads
and params to FP32 internally to support ``--memory-efficient-fp16``.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED in FusedAdam!
eps_inside_sqrt (boolean, optional): in the "update parameters" step,
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params,
lr=1e-3, bias_correction=True,
betas=(0.9, 0.999), eps=1e-8,
weight_decay=0., amsgrad=False):
global unicore_fused_adam
import importlib
unicore_fused_adam = importlib.import_module("unicore_fused_adam")
if amsgrad:
raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
defaults = {
"lr": lr,
"bias_correction": bias_correction,
"betas": betas,
"eps": eps,
"weight_decay": weight_decay,
}
super().__init__(params, defaults)
@property
def supports_memory_efficient_fp16(self):
return True
@property
def supports_flat_params(self):
return True
@property
def supports_step_with_scale(self):
return True
def step(self, closure=None, scale=1.):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
scale (float, optional): factor to divide gradient tensor values
by before applying to weights. (default: 1)
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
# compute combined scale factor for this group
combined_scale = scale
bias_correction = 1 if group.get("bias_correction", 1) else 0
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError(
"FusedAdam does not support sparse gradients, "
"please consider SparseAdam instead"
)
state = self.state[p]
# State initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p.data, dtype=torch.float)
else:
state["exp_avg"] = state["exp_avg"].to(dtype=torch.float)
state["exp_avg_sq"] = state["exp_avg_sq"].to(dtype=torch.float)
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
with torch.cuda.device(p.device):
unicore_fused_adam.adam(p.data,
exp_avg,
exp_avg_sq,
grad,
group["lr"],
beta1,
beta2,
group["eps"],
combined_scale,
state["step"],
bias_correction,
group["weight_decay"])
return loss
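# Illustrative sketch (not part of the original file): get_fused_adam_class()
# returns the FusedAdam class only when the `unicore_fused_adam` CUDA extension
# is importable, so callers can probe for it and fall back to a stock optimizer.
# The fallback to torch.optim.Adam below is an assumption for the demo, not
# something this module does itself.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 2)
    fused_adam_cls = get_fused_adam_class()
    if fused_adam_cls is not None and torch.cuda.is_available():
        optimizer = fused_adam_cls(model.parameters(), lr=1e-3)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    print(type(optimizer).__name__)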
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import importlib
import os
from unicore import registry
from unicore.optim.lr_scheduler.unicore_lr_scheduler import ( # noqa
UnicoreLRScheduler,
)
(
build_lr_scheduler_,
register_lr_scheduler,
LR_SCHEDULER_REGISTRY,
) = registry.setup_registry(
"--lr-scheduler", base_class=UnicoreLRScheduler, default="fixed"
)
def build_lr_scheduler(args, optimizer, total_train_steps):
return build_lr_scheduler_(args, optimizer, total_train_steps)
# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
file_name = file[: file.find(".py")]
importlib.import_module("unicore.optim.lr_scheduler." + file_name)
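# Illustrative sketch (not part of the original file): a new scheduler is added
# by decorating a UnicoreLRScheduler subclass with @register_lr_scheduler and
# dropping the file into this directory, after which it becomes selectable via
# --lr-scheduler. The name "constant" and the class below are hypothetical, so
# the example is left commented out.
#
# @register_lr_scheduler("constant")
# class ConstantLRSchedule(UnicoreLRScheduler):
#     def __init__(self, args, optimizer, total_train_steps):
#         super().__init__(args, optimizer, total_train_steps)
#         self.lr = args.lr[0]
#         self.optimizer.set_lr(self.lr)
#
#     def step_update(self, num_updates):
#         return self.optimizer.get_lr()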
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from collections.abc import Collection
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("cosine")
class CosineLRSchedule(UnicoreLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (``--warmup-init-lr``) until the configured
max learning rate (``--lr``).
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup::
lr = args.min_lr + 0.5*(args.lr - args.min_lr)*(1 + cos(pi * t_curr / t_i))
where ``t_curr`` is the number of updates since the start of the current period
and ``t_i`` is the length of the current period, which is scaled by ``t_mult``
after every period.
"""
def __init__(self, args, unicore_optimizer, total_train_steps):
super().__init__(args, unicore_optimizer, total_train_steps)
if isinstance(args.lr, Collection) and len(args.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with cosine."
f" Consider --lr-scheduler=fixed instead. ({args.lr})"
)
self.max_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr
assert (
self.max_lr > args.min_lr
), f"max_lr (={args.lr}) must be more than min_lr (={args.min_lr})"
warmup_end_lr = self.max_lr
if args.warmup_init_lr < 0:
args.warmup_init_lr = args.min_lr
self.t_mult = args.t_mult
self.period = args.lr_period_updates
if self.period <= 0:
assert (
args.max_update > 0
), "Either --max_update or --lr-period-updates must be set"
self.period = args.max_update - args.warmup_updates
if args.warmup_updates > 0:
# linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
else:
self.lr_step = 1
self.warmup_updates = args.warmup_updates
self.lr_shrink = args.lr_shrink
# initial learning rate
self.lr = args.warmup_init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr')
parser.add_argument('--max-lr', type=float, metavar='LR',
help='max learning rate, must be more than args.lr')
parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
help='factor to grow the length of each period')
parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR',
help='initial number of updates per period')
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing')
# fmt: on
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
else:
curr_updates = num_updates - self.args.warmup_updates
if self.t_mult != 1:
i = math.floor(
math.log(
1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult
)
)
t_i = self.t_mult ** i * self.period
t_curr = (
curr_updates
- (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
)
else:
i = math.floor(curr_updates / self.period)
t_i = self.period
t_curr = curr_updates - (self.period * i)
lr_shrink = self.lr_shrink ** i
min_lr = self.args.min_lr * lr_shrink
max_lr = self.max_lr * lr_shrink
self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
1 + math.cos(math.pi * t_curr / t_i)
)
self.optimizer.set_lr(self.lr)
return self.lr
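# Illustrative sketch (not part of the original file): the update rule of
# step_update() above, reproduced as a standalone function for the common
# t_mult == 1, lr_shrink == 1 case. Names are local to the demo, and warmup is
# assumed to start from min_lr (the default warmup_init_lr).
if __name__ == "__main__":
    import math

    def demo_cosine_lr(num_updates, max_lr, min_lr, warmup_updates, period):
        if num_updates < warmup_updates:
            # linear warmup from min_lr up to max_lr
            return min_lr + num_updates * (max_lr - min_lr) / warmup_updates
        t_curr = (num_updates - warmup_updates) % period
        return min_lr + 0.5 * (max_lr - min_lr) * (
            1 + math.cos(math.pi * t_curr / period)
        )

    for step in (0, 500, 1000, 6000, 11000):
        print(step, round(demo_cosine_lr(step, 1e-3, 1e-5, 1000, 10000), 8))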
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("exponential_decay")
class ExponentialDecayLRSchedule(UnicoreLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
self.warmup_updates = args.warmup_updates
self.lr = args.lr[0]
if self.warmup_updates > 0:
self.warmup_factor = 1.0 / self.warmup_updates
else:
self.warmup_factor = 1.0
self.decay_ratio = args.decay_ratio
self.decay_steps = args.decay_steps
self.optimizer.set_lr(self.warmup_factor * self.lr)
self.stair_decay = getattr(args, "stair_decay", False)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--warmup-updates', default=1000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--decay-ratio', default=0.95, type=float)
parser.add_argument('--decay-steps', default=500, type=int)
parser.add_argument('--stair-decay', action="store_true")
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
self.warmup_factor = num_updates / float(self.warmup_updates)
lr = self.warmup_factor * self.lr
else:
if self.stair_decay:
step = num_updates
lr = self.lr * float(self.decay_ratio ** (int(step // self.decay_steps)))
else:
step = num_updates - self.warmup_updates
lr = self.lr * float(self.decay_ratio ** (float(step / self.decay_steps)))
self.optimizer.set_lr(lr)
return self.optimizer.get_lr()
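# Illustrative sketch (not part of the original file): the update rule of
# step_update() above as a standalone function, covering both the smooth and
# the "stair" variants. Names are local to the demo.
if __name__ == "__main__":
    def demo_exp_decay_lr(num_updates, lr, warmup_updates=1000,
                          decay_ratio=0.95, decay_steps=500, stair_decay=False):
        if warmup_updates > 0 and num_updates <= warmup_updates:
            return num_updates / float(warmup_updates) * lr
        if stair_decay:
            return lr * decay_ratio ** (num_updates // decay_steps)
        step = num_updates - warmup_updates
        return lr * decay_ratio ** (step / decay_steps)

    print(demo_exp_decay_lr(500, 1e-3))    # mid-warmup: 5e-4
    print(demo_exp_decay_lr(2000, 1e-3))   # 1e-3 * 0.95 ** 2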
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("fixed")
class FixedLRSchedule(UnicoreLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
self.lr = args.lr[0]
if args.warmup_updates > 0:
self.warmup_factor = 1.0 / args.warmup_updates
else:
self.warmup_factor = 1
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
# fmt: on
def state_dict(self):
return {"lr": self.lr}
def load_state_dict(self, state_dict):
if "lr" in state_dict:
self.lr = state_dict["lr"]
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch - 1, len(lrs) - 1)]
else:
# anneal based on lr_shrink
next_lr = lrs[-1] * self.args.lr_shrink ** (
epoch + 1 - self.args.force_anneal
)
return next_lr
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates:
self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates)
self.optimizer.set_lr(self.warmup_factor * self.lr)
else:
self.optimizer.set_lr(self.lr)
return self.optimizer.get_lr()
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from collections.abc import Collection
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("inverse_sqrt")
class InverseSquareRootSchedule(UnicoreLRScheduler):
"""Decay the LR based on the inverse square root of the update number.
We also support a warmup phase where we linearly increase the learning rate
from some initial learning rate (``--warmup-init-lr``) until the configured
learning rate (``--lr``). Thereafter we decay proportional to the number of
updates, with a decay factor set to align with the configured learning rate.
During warmup::
lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
lr = lrs[update_num]
After warmup::
decay_factor = args.lr * sqrt(args.warmup_updates)
lr = decay_factor / sqrt(update_num)
"""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
if isinstance(args.lr, Collection) and len(args.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with inverse_sqrt."
" Consider --lr-scheduler=fixed instead."
)
warmup_end_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr
if args.warmup_init_lr < 0:
args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
# linearly warmup for the first args.warmup_updates
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
# then, decay prop. to the inverse square root of the update number
self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5
# initial learning rate
self.lr = args.warmup_init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr')
# fmt: on
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if num_updates < self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
else:
self.lr = self.decay_factor * num_updates ** -0.5
self.optimizer.set_lr(self.lr)
return self.lr
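# Illustrative sketch (not part of the original file): the schedule described
# in the class docstring, as a standalone function. Names are local to the demo.
if __name__ == "__main__":
    def demo_inverse_sqrt_lr(num_updates, lr, warmup_updates=4000,
                             warmup_init_lr=0.0):
        if num_updates < warmup_updates:
            lr_step = (lr - warmup_init_lr) / warmup_updates
            return warmup_init_lr + num_updates * lr_step
        decay_factor = lr * warmup_updates ** 0.5
        return decay_factor / num_updates ** 0.5

    print(demo_inverse_sqrt_lr(2000, 5e-4))   # half-way through warmup: 2.5e-4
    print(demo_inverse_sqrt_lr(16000, 5e-4))  # 5e-4 * sqrt(4000 / 16000) = 2.5e-4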
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("pass_through")
class PassThroughScheduleSchedule(UnicoreLRScheduler):
"""Delegate lr scheduling to the optimizer."""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
assert (
hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None
), "Pass-through schedule can only be used with optimizers with their own schedulers"
def state_dict(self):
return self.optimizer.lr_scheduler.state_dict()
def load_state_dict(self, state_dict):
self.optimizer.lr_scheduler.load_state_dict(state_dict)
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
return self.optimizer.lr_scheduler.step_begin_epoch(epoch)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.lr_scheduler.step_update(num_updates)
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("polynomial_decay")
class PolynomialDecayLRSchedule(UnicoreLRScheduler):
"""Decay the LR on a fixed schedule."""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
if self.args.warmup_ratio > 0:
# if warmup_ratio > 0, use external train steps
assert total_train_steps is not None
self.warmup_updates = int(self.args.warmup_ratio * total_train_steps)
self.total_num_update = total_train_steps
else:
assert args.total_num_update > 0
self.warmup_updates = args.warmup_updates
self.total_num_update = args.total_num_update
self.lr = args.lr[0]
if self.warmup_updates > 0:
self.warmup_factor = 1.0 / self.warmup_updates
else:
self.warmup_factor = 1
self.end_learning_rate = args.end_learning_rate
self.power = args.power
self.optimizer.set_lr(self.warmup_factor * self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
help='force annealing at specified epoch')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-ratio', default=-1.0, type=float, metavar='N',
help='warmup the learning rate linearly for the first N-percent updates')
parser.add_argument('--end-learning-rate', default=0.0, type=float)
parser.add_argument('--power', default=1.0, type=float)
parser.add_argument('--total-num-update', default=1000000, type=int)
def get_next_lr(self, epoch):
lrs = self.args.lr
if self.args.force_anneal is None or epoch < self.args.force_anneal:
# use fixed LR schedule
next_lr = lrs[min(epoch, len(lrs) - 1)]
else:
# anneal based on lr_shrink
next_lr = self.optimizer.get_lr()
return next_lr
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
self.lr = self.get_next_lr(epoch)
self.optimizer.set_lr(self.warmup_factor * self.lr)
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
self.warmup_factor = num_updates / float(self.warmup_updates)
lr = self.warmup_factor * self.lr
elif num_updates >= self.total_num_update:
lr = self.end_learning_rate
else:
warmup = self.warmup_updates
lr_range = self.lr - self.end_learning_rate
pct_remaining = 1 - (num_updates - warmup) / (
self.total_num_update - warmup
)
lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate
self.optimizer.set_lr(lr)
return self.optimizer.get_lr()
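# Illustrative sketch (not part of the original file): the polynomial decay
# rule of step_update() above as a standalone function. Names are local to the
# demo.
if __name__ == "__main__":
    def demo_polynomial_lr(num_updates, lr, warmup_updates, total_num_update,
                           end_learning_rate=0.0, power=1.0):
        if warmup_updates > 0 and num_updates <= warmup_updates:
            return num_updates / float(warmup_updates) * lr
        if num_updates >= total_num_update:
            return end_learning_rate
        pct_remaining = 1 - (num_updates - warmup_updates) / (
            total_num_update - warmup_updates
        )
        return (lr - end_learning_rate) * pct_remaining ** power + end_learning_rate

    # linear decay (power=1) from 1e-3 to 0 over 100k updates after 10k warmup
    print(demo_polynomial_lr(55000, 1e-3, 10000, 100000))  # 5e-4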
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch.optim.lr_scheduler
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler(
"reduce_lr_on_plateau"
)
class ReduceLROnPlateauLRSchedule(UnicoreLRScheduler):
"""
Decay the LR by a factor every time the validation loss plateaus.
Also comes with optional warmup phase, where we linearly increase
the learning rate from some initial learning rate
(``--warmup-init-lr``) until the configured learning rate
(``--lr``). Thereafter the LR is adjusted according to the original
reduce_on_plateau scheme.
During warmup::
lrs = torch.linspace(
args.warmup_init_lr, args.lr, args.warmup_updates
)
lr = lrs[update_num]
"""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
if len(args.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with reduce_lr_on_plateau."
" Consider --lr-scheduler=fixed instead."
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer,
patience=args.lr_patience,
factor=args.lr_shrink,
mode="max" if args.maximize_best_checkpoint_metric else "min",
threshold=args.lr_threshold,
)
warmup_end_lr = args.lr[0]
# if no warm up, sets initial lr to be args.lr[0]
if args.warmup_init_lr < 0:
args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr
# linearly warmup for the first args.warmup_updates
if args.warmup_updates > 0:
self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
# this flag is either set from arg when no warm up, or set by
# step_update() when warmup finishes
self.warmup_end = True if args.warmup_updates <= 0 else False
# initial learning rate
# this self.lr is used only during init and/or warm up period
self.lr = args.warmup_init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing, lr_new = (lr * lr_shrink)')
parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT',
help='Threshold for measuring the new optimum, \
to only focus on significant changes')
parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
help='warmup the learning rate linearly for the first N updates')
parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
help='initial learning rate during warmup phase; default is args.lr')
# fmt: on
def state_dict(self):
"""Return the LR scheduler state dict."""
return {
"best": self.lr_scheduler.best,
"last_epoch": self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.lr_scheduler.best = state_dict["best"]
if "last_epoch" in state_dict:
self.lr_scheduler.last_epoch = state_dict["last_epoch"]
def step(self, epoch, val_loss=None):
"""
Update the learning rate at the end of the given epoch if warmup
has finished; otherwise the LR is not updated at epoch boundaries.
"""
if val_loss is not None and self.warmup_end is True:
self.lr_scheduler.step(val_loss)
else:
self.lr_scheduler.last_epoch = epoch
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""
Update the learning rate after each update."""
# if there is warmup
if self.args.warmup_updates > 0:
if num_updates <= self.args.warmup_updates:
self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
self.optimizer.set_lr(self.lr)
else:
if self.warmup_end is False:
self.warmup_end = True
# else do nothing
return self.optimizer.get_lr()
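# Illustrative sketch (not part of the original file): after warmup, the class
# above delegates to torch.optim.lr_scheduler.ReduceLROnPlateau; the bare
# PyTorch usage it wraps looks roughly like this.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=2, threshold=1e-4
    )
    for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9]):
        scheduler.step(val_loss)  # shrink lr after `patience` non-improving epochs
        print(epoch, optimizer.param_groups[0]["lr"])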
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("tri_stage")
class TriStageLRSchedule(UnicoreLRScheduler):
"""Tristage learning rate schedulr
Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf
Similar to inverse_squre_root scheduler, but tri_stage learning rate employs
three stages LR scheduling:
- warmup stage, starting from `lr` * `init_lr_scale`, linearly
increased to `lr` in `warmup_steps` iterations
- hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps`
iterations
- decay stage, after hold stage, decay LR exponetially to
`lr` * `final_lr_scale` in `decay_steps`;
after that LR is keep as `final_lr_scale` * `lr`
During warmup::
init_lr = args.init_lr_scale * args.lr
lrs = torch.linspace(init_lr, args.lr, args.warmup_steps)
lr = lrs[update_num]
During hold::
lr = args.lr
During decay::
decay_factor = - math.log(args.final_lr_scale) / args.decay_steps
lr = args.lr * exp(- (update_num - warmup_steps - hold_steps) * decay_factor)
After that::
lr = args.lr * args.final_lr_scale
"""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
if len(args.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with tri-stage lr."
" Consider --lr-scheduler=fixed instead."
)
# calculate LR at each point
self.peak_lr = args.lr[0]
self.init_lr = args.init_lr_scale * args.lr[0]
self.final_lr = args.final_lr_scale * args.lr[0]
if args.phase_ratio is not None:
assert args.max_update > 0
assert sum(args.phase_ratio) == 1, "phase ratios must add up to 1"
self.warmup_steps = int(args.max_update * args.phase_ratio[0])
self.hold_steps = int(args.max_update * args.phase_ratio[1])
self.decay_steps = int(args.max_update * args.phase_ratio[2])
else:
self.warmup_steps = args.warmup_steps
self.hold_steps = args.hold_steps
self.decay_steps = args.decay_steps
assert (
self.warmup_steps + self.hold_steps + self.decay_steps > 0
), "please specify steps or phase_ratio"
self.warmup_rate = (
(self.peak_lr - self.init_lr) / self.warmup_steps
if self.warmup_steps != 0
else 0
)
self.decay_factor = -math.log(args.final_lr_scale) / self.decay_steps
# initial learning rate
self.lr = self.init_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument(
'--warmup-steps',
default=4000,
type=int,
metavar='N',
help='warmup the learning rate linearly for the first N updates'
)
parser.add_argument(
'--hold-steps',
default=20000,
type=int,
metavar='N',
help='steps in hold stage.'
)
parser.add_argument(
'--decay-steps',
default=60000,
type=int,
metavar='N',
help='steps in decay stages'
)
parser.add_argument(
'--init-lr-scale',
default=0.01,
type=float,
help="""
initial learning rate scale during warmup phase; default is 0.01""")
parser.add_argument(
'--final-lr-scale',
default=0.01,
type=float,
help="final learning rate scale; default to 0.01"
)
# fmt: on
def _decide_stage(self, update_step):
"""
return stage, and the corresponding steps within the current stage
"""
if update_step < self.warmup_steps:
# warmup state
return 0, update_step
offset = self.warmup_steps
if update_step < offset + self.hold_steps:
# hold stage
return 1, update_step - offset
offset += self.hold_steps
if update_step <= offset + self.decay_steps:
# decay stage
return 2, update_step - offset
offset += self.decay_steps
# still here ? constant lr stage
return 3, update_step - offset
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
stage, steps_in_stage = self._decide_stage(num_updates)
if stage == 0:
self.lr = self.init_lr + self.warmup_rate * steps_in_stage
elif stage == 1:
self.lr = self.peak_lr
elif stage == 2:
self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
elif stage == 3:
self.lr = self.final_lr
else:
raise ValueError("Undefined stage")
self.optimizer.set_lr(self.lr)
return self.lr
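# Illustrative sketch (not part of the original file): the three-stage rule of
# _decide_stage()/step_update() above as a standalone function. Names are local
# to the demo.
if __name__ == "__main__":
    import math

    def demo_tri_stage_lr(num_updates, peak_lr, warmup_steps=4000,
                          hold_steps=20000, decay_steps=60000,
                          init_lr_scale=0.01, final_lr_scale=0.01):
        init_lr = init_lr_scale * peak_lr
        final_lr = final_lr_scale * peak_lr
        if num_updates < warmup_steps:
            return init_lr + (peak_lr - init_lr) * num_updates / warmup_steps
        if num_updates < warmup_steps + hold_steps:
            return peak_lr
        if num_updates <= warmup_steps + hold_steps + decay_steps:
            decay_factor = -math.log(final_lr_scale) / decay_steps
            return peak_lr * math.exp(
                -decay_factor * (num_updates - warmup_steps - hold_steps)
            )
        return final_lr

    for step in (0, 2000, 10000, 54000, 100000):
        print(step, demo_tri_stage_lr(step, 1e-3))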
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import List
from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
@register_lr_scheduler("triangular")
class TriangularLRSchedule(UnicoreLRScheduler):
"""Assign LR based on a triangular cyclical schedule.
See https://arxiv.org/pdf/1506.01186.pdf for details.
"""
def __init__(self, args, optimizer, total_train_steps):
super().__init__(args, optimizer, total_train_steps)
if len(args.lr) > 1:
raise ValueError(
"Cannot use a fixed learning rate schedule with triangular."
" Consider --lr-scheduler=fixed instead."
)
lr = args.lr[0]
assert args.max_lr > lr, "max_lr must be more than lr"
self.min_lr = lr
self.max_lr = args.max_lr
self.stepsize = args.lr_period_updates // 2
self.lr_shrink = args.lr_shrink
self.shrink_min = args.shrink_min
# initial learning rate
self.lr = self.min_lr
self.optimizer.set_lr(self.lr)
@staticmethod
def add_args(parser):
"""Add arguments to the parser for this LR scheduler."""
# fmt: off
parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
help='max learning rate, must be more than args.lr')
parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
help='initial number of updates per period (cycle length)')
parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='shrink factor for annealing')
parser.add_argument('--shrink-min', action='store_true',
help='if set, also shrinks min lr')
# fmt: on
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
super().step(epoch, val_loss)
# we don't change the learning rate at epoch boundaries
return self.optimizer.get_lr()
def step_update(self, num_updates):
"""Update the learning rate after each update."""
cycle = math.floor(num_updates / (2 * self.stepsize))
lr_shrink = self.lr_shrink ** cycle
max_lr = self.max_lr * lr_shrink
if self.shrink_min:
min_lr = self.min_lr * lr_shrink
else:
min_lr = self.min_lr
x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))
self.optimizer.set_lr(self.lr)
return self.lr
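# Illustrative sketch (not part of the original file): the cyclical rule of
# step_update() above as a standalone function, with shrinking disabled
# (lr_shrink = 1). Names are local to the demo.
if __name__ == "__main__":
    import math

    def demo_triangular_lr(num_updates, min_lr, max_lr, period=5000):
        stepsize = period // 2
        cycle = math.floor(num_updates / (2 * stepsize))
        x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)
        return min_lr + (max_lr - min_lr) * max(0, 1 - x)

    for step in (0, 1250, 2500, 3750, 5000):
        print(step, demo_triangular_lr(step, 1e-4, 1e-3))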
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from argparse import Namespace
from unicore.optim import UnicoreOptimizer
class UnicoreLRScheduler(object):
def __init__(self, args, optimizer, total_train_steps):
super().__init__()
if optimizer is not None and not isinstance(optimizer, UnicoreOptimizer):
raise ValueError("optimizer must be an instance of UnicoreOptimizer")
self.args = args
self.optimizer = optimizer
self.total_train_steps = total_train_steps
self.best = None
@classmethod
def add_args(cls, parser):
"""Add arguments to the parser for this LR scheduler."""
pass
def state_dict(self):
"""Return the LR scheduler state dict."""
return {"best": self.best}
def load_state_dict(self, state_dict):
"""Load an LR scheduler state dict."""
self.best = state_dict["best"]
def step_begin_epoch(self, epoch):
"""Update the learning rate at the beginning of the given epoch."""
pass
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None:
if self.best is None:
self.best = val_loss
else:
self.best = min(self.best, val_loss)
def step_update(self, num_updates):
"""Update the learning rate after each update."""
return self.optimizer.get_lr()
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch.optim
from . import UnicoreOptimizer, register_optimizer
@register_optimizer("sgd")
class SGD(UnicoreOptimizer):
def __init__(self, args, params):
super().__init__(args)
self._optimizer = torch.optim.SGD(params, **self.optimizer_config)
@staticmethod
def add_args(parser):
"""Add optimizer-specific arguments to the parser."""
# fmt: off
parser.add_argument('--momentum', default=0.0, type=float, metavar='M',
help='momentum factor')
parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
help='weight decay')
# fmt: on
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
return {
"lr": self.args.lr[0],
"momentum": self.args.momentum,
"weight_decay": self.args.weight_decay,
}
@property
def supports_flat_params(self):
return True
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
from unicore import utils
class UnicoreOptimizer(object):
def __init__(self, args):
super().__init__()
self.args = args
self._grad_buffer = None
self._need_sync_grad_buf = False
@classmethod
def add_args(cls, parser):
"""Add optimizer-specific arguments to the parser."""
pass
@property
def optimizer(self):
"""Return a torch.optim.optimizer.Optimizer instance."""
if not hasattr(self, "_optimizer"):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
return self._optimizer
@optimizer.setter
def optimizer(self, optimizer):
"""Reset optimizer instance."""
if not hasattr(self, "_optimizer"):
raise NotImplementedError
if not isinstance(self._optimizer, torch.optim.Optimizer):
raise ValueError("_optimizer must be an instance of torch.optim.Optimizer")
self._optimizer = optimizer
@property
def optimizer_config(self):
"""
Return a kwarg dictionary that will be used to override optimizer
args stored in checkpoints. This allows us to load a checkpoint and
resume training using a different set of optimizer args, e.g., with a
different learning rate.
"""
raise NotImplementedError
@property
def params(self):
"""Return an iterable of the parameters held by the optimizer."""
for param_group in self.param_groups:
for p in param_group["params"]:
yield p
@property
def param_groups(self):
return self.optimizer.param_groups
def __getstate__(self):
return self._optimizer.__getstate__()
def get_lr(self):
"""Return the current learning rate."""
return self.param_groups[0]["lr"]
def set_lr(self, lr):
"""Set the learning rate."""
for param_group in self.param_groups:
param_group["lr"] = lr
def state_dict(self):
"""Return the optimizer's state dict."""
return self.optimizer.state_dict()
def load_state_dict(self, state_dict, optimizer_overrides=None):
"""Load an optimizer state dict.
In general we should prefer the configuration of the existing optimizer
instance (e.g., learning rate) over that found in the state_dict. This
allows us to resume training from a checkpoint using a new set of
optimizer args.
"""
self.optimizer.load_state_dict(state_dict)
if optimizer_overrides is not None and len(optimizer_overrides) > 0:
# override learning rate, momentum, etc. with latest values
for group in self.param_groups:
group.update(optimizer_overrides)
def backward(self, loss):
"""Computes the sum of gradients of the given tensor w.r.t. graph leaves."""
loss.backward()
def all_reduce_grads(self, module):
"""Manually all-reduce gradients (if required)."""
self.__sync_grad_from_buf__()
if hasattr(module, "all_reduce_grads"):
module.all_reduce_grads()
def multiply_grads(self, c):
"""Multiplies grads by a constant *c*."""
for p in self.params:
if p.grad is not None:
if torch.is_tensor(c):
c = c.to(p.grad.device)
p.grad.data.mul_(c)
def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm."""
if max_norm <= 0.0:
return 0.0
if self._grad_buffer is None:
self._grad_buffer = [torch.zeros_like(g) for g in self.params]
gnorm = utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
for i, p in enumerate(self.params):
if p.grad is None:
continue
self._grad_buffer[i] += p.grad
p.grad = None
self._need_sync_grad_buf = True
return gnorm
def __sync_grad_from_buf__(self):
if self._need_sync_grad_buf:
assert self._grad_buffer is not None
for i, p in enumerate(self.params):
p.grad = self._grad_buffer[i]
self._need_sync_grad_buf = False
def clip_grad_norm(self, max_norm, aggregate_norm_fn=None):
"""Clips gradient norm."""
self.__sync_grad_from_buf__()
return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn)
def step(self, closure=None, scale=1.0, groups=None):
"""Performs a single optimization step."""
self.__sync_grad_from_buf__()
if self.supports_step_with_scale:
if self.supports_groups:
self.optimizer.step(closure, scale=scale, groups=groups)
else:
self.optimizer.step(closure, scale=scale)
else:
if scale != 1.0:
self.multiply_grads(1.0 / scale)
if self.supports_groups:
self.optimizer.step(closure, groups=groups)
else:
self.optimizer.step(closure)
def zero_grad(self):
"""Clears the gradients of all optimized parameters."""
for p in self.params:
p.grad = None
self.optimizer.zero_grad()
self._need_sync_grad_buf = False
if self._grad_buffer is not None:
for t in self._grad_buffer:
t.zero_()
@property
def supports_memory_efficient_fp16(self):
if hasattr(self.optimizer, "supports_memory_efficient_fp16"):
return self.optimizer.supports_memory_efficient_fp16
return False
@property
def supports_step_with_scale(self):
if hasattr(self.optimizer, "supports_step_with_scale"):
return self.optimizer.supports_step_with_scale
return False
@property
def supports_groups(self):
if hasattr(self.optimizer, "supports_groups"):
return self.optimizer.supports_groups
return False
@property
def supports_flat_params(self):
"""
Whether the optimizer supports collapsing of the model
parameters/gradients into a single contiguous Tensor.
"""
if hasattr(self.optimizer, "supports_flat_params"):
return self.optimizer.supports_flat_params
return False
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import torch
from typing import Callable, List, Optional
# this import is for backward compatibility
from unicore.utils import csv_str_list, eval_bool, eval_str_dict, eval_str_list, import_user_module # noqa
def get_training_parser(default_task="translation"):
parser = get_parser("Trainer", default_task)
add_dataset_args(parser, train=True)
add_distributed_training_args(parser)
add_model_args(parser)
add_optimization_args(parser)
add_checkpoint_args(parser)
return parser
def get_validation_parser(default_task=None):
parser = get_parser("Validation", default_task)
add_dataset_args(parser, train=True)
add_distributed_training_args(parser)
group = parser.add_argument_group("Evaluation")
add_common_eval_args(group)
return parser
def parse_args_and_arch(
parser: argparse.ArgumentParser,
input_args: List[str] = None,
parse_known: bool = False,
suppress_defaults: bool = False,
modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
"""
Args:
parser (ArgumentParser): the parser
input_args (List[str]): strings to parse, defaults to sys.argv
parse_known (bool): only parse known arguments, similar to
`ArgumentParser.parse_known_args`
suppress_defaults (bool): parse while ignoring all default values
modify_parser (Optional[Callable[[ArgumentParser], None]]):
function to modify the parser, e.g., to set default values
"""
if suppress_defaults:
# Parse args without any default values. This requires us to parse
# twice, once to identify all the necessary task/model args, and a second
# time with all defaults set to None.
args = parse_args_and_arch(
parser,
input_args=input_args,
parse_known=parse_known,
suppress_defaults=False,
)
suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
args = suppressed_parser.parse_args(input_args)
return argparse.Namespace(
**{k: v for k, v in vars(args).items() if v is not None}
)
from unicore.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY
# Before creating the true parser, we need to import optional user module
# in order to eagerly import custom tasks, optimizers, architectures, etc.
usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
usr_parser.add_argument("--user-dir", default=None)
usr_args, _ = usr_parser.parse_known_args(input_args)
import_user_module(usr_args)
if modify_parser is not None:
modify_parser(parser)
# The parser doesn't know about model/loss/optimizer-specific args, so
# we parse twice. First we parse the model/loss/optimizer, then we
# parse a second time after adding the *-specific arguments.
# If input_args is given, we will parse those args instead of sys.argv.
args, _ = parser.parse_known_args(input_args)
# Add model-specific args to parser.
if hasattr(args, "arch"):
model_specific_group = parser.add_argument_group(
"Model-specific configuration",
# Only include attributes which are explicitly given as command-line
# arguments or which have default values.
argument_default=argparse.SUPPRESS,
)
if args.arch in ARCH_MODEL_REGISTRY:
ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
elif args.arch in MODEL_REGISTRY:
MODEL_REGISTRY[args.arch].add_args(model_specific_group)
else:
raise RuntimeError()
if hasattr(args, "task"):
from unicore.tasks import TASK_REGISTRY
TASK_REGISTRY[args.task].add_args(parser)
# Add *-specific args to parser.
from unicore.registry import REGISTRIES
for registry_name, REGISTRY in REGISTRIES.items():
choice = getattr(args, registry_name, None)
if choice is not None:
cls = REGISTRY["registry"][choice]
if hasattr(cls, "add_args"):
cls.add_args(parser)
# Modify the parser a second time, since defaults may have been reset
if modify_parser is not None:
modify_parser(parser)
# Parse a second time.
if parse_known:
args, extra = parser.parse_known_args(input_args)
else:
args = parser.parse_args(input_args)
extra = None
# Post-process args.
if (
hasattr(args, "batch_size_valid") and args.batch_size_valid is None
) or not hasattr(args, "batch_size_valid"):
args.batch_size_valid = args.batch_size
args.bf16 = getattr(args, "bf16", False)
if getattr(args, "seed", None) is None:
args.seed = 1 # default seed for training
args.no_seed_provided = True
else:
args.no_seed_provided = False
# Apply architecture configuration.
if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
ARCH_CONFIG_REGISTRY[args.arch](args)
if parse_known:
return args, extra
else:
return args
def get_parser(desc, default_task='test'):
# Before creating the true parser, we need to import optional user module
# in order to eagerly import custom tasks, optimizers, architectures, etc.
usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
usr_parser.add_argument('--user-dir', default=None)
usr_args, _ = usr_parser.parse_known_args()
import_user_module(usr_args)
parser = argparse.ArgumentParser(allow_abbrev=False)
# fmt: off
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N batches (when progress bar is disabled)')
parser.add_argument('--log-format', default=None, help='log format to use',
choices=['json', 'none', 'simple', 'tqdm'])
parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
help='path to save logs for tensorboard, should match --logdir '
'of running tensorboard (default: no tensorboard logging)')
parser.add_argument('--seed', default=1, type=int, metavar='N',
help='pseudo random number generator seed')
parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
parser.add_argument('--fp16', action='store_true', help='use FP16')
parser.add_argument('--bf16', action='store_true', help='use BF16')
parser.add_argument('--bf16-sr', action='store_true', help='use stochastic rounding for bf16')
parser.add_argument('--allreduce-fp32-grad', action='store_true', help='use fp32-grads in fp16/bf16 mode. --ddp-backend should be no_c10d')
parser.add_argument('--fp16-no-flatten-grads', action='store_true', help="don't flatten FP16 grads tensor")
parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
help='default FP16 loss scale')
parser.add_argument('--fp16-scale-window', type=int,
help='number of updates before increasing loss scale')
parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
help='pct of updates that can overflow before decreasing the loss scale')
parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
help='minimum FP16 loss scale, after which training is stopped')
parser.add_argument('--threshold-loss-scale', type=float,
help='threshold FP16 loss scale from below')
parser.add_argument('--user-dir', default=None,
help='path to a python module containing custom extensions (tasks and/or architectures)')
parser.add_argument('--empty-cache-freq', default=0, type=int,
help='how often to clear the PyTorch CUDA cache (0 to disable)')
parser.add_argument('--all-gather-list-size', default=16384, type=int,
help='number of bytes reserved for gathering stats from workers')
parser.add_argument('--suppress-crashes', action='store_true', help="suppress crashes when training with the entry point so that the "
"main method can return a value (useful for sweeps)")
parser.add_argument('--profile', action='store_true', help="enable autograd profiler emit_nvtx")
parser.add_argument('--ema-decay', default=-1.0, type=float, help="enable moving average for model weights")
from unicore.registry import REGISTRIES
for registry_name, REGISTRY in REGISTRIES.items():
parser.add_argument(
'--' + registry_name.replace('_', '-'),
default=REGISTRY['default'],
choices=REGISTRY['registry'].keys(),
)
# Task definitions can be found under unicore/tasks/
from unicore.tasks import TASK_REGISTRY
parser.add_argument('--task', metavar='TASK', default=default_task,
choices=TASK_REGISTRY.keys(),
help='task')
# fmt: on
return parser
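# Illustrative sketch (not part of the original file): typical entry-point usage
# of the helpers in this module -- build the training parser, then let
# parse_args_and_arch() add model/loss/optimizer-specific arguments before the
# final parse. The task name "my_task" and `my_argv` below are hypothetical.
#
#   parser = get_training_parser(default_task="my_task")
#   args = parse_args_and_arch(parser, input_args=my_argv)  # my_argv: list of CLI strings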
def add_dataset_args(parser, train=False, gen=False):
group = parser.add_argument_group('Dataset and data loading')
# fmt: off
group.add_argument('--num-workers', default=1, type=int, metavar='N',
help='how many subprocesses to use for data loading')
group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true',
help='ignore too long or too short lines in valid and test set')
group.add_argument('--batch-size', '--max-sentences', type=int, metavar='N',
help='maximum number of sentences in a batch')
group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
help='batch size will be a multiple of this value')
group.add_argument('--data-buffer-size', default=10, type=int,
help='Number of batches to preload')
group.add_argument('--train-subset', default='train', metavar='SPLIT',
choices=['train', 'valid', 'test', 'train.small'],
help='data subset to use for training (train, valid, test)')
group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
help='comma separated list of data subsets to use for validation'
' (train, valid, valid1, test, test1)')
group.add_argument('--validate-interval', type=int, default=1, metavar='N',
help='validate every N epochs')
group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N',
help='validate every N updates')
group.add_argument('--validate-after-updates', type=int, default=0, metavar='N',
help="don't validate until reaching this many updates")
group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N',
help='specified random seed for validation')
group.add_argument('--disable-validation', action='store_true',
help='disable validation')
group.add_argument('--batch-size-valid', type=int, metavar='N',
help='maximum number of sentences in a validation batch'
' (defaults to --max-sentences)')
group.add_argument('--max-valid-steps', type=int, metavar='N',
help='maximum number of batches to evaluate during validation')
group.add_argument('--curriculum', default=0, type=int, metavar='N',
help='don\'t shuffle batches for first N epochs')
# fmt: on
return group
def add_distributed_training_args(parser):
group = parser.add_argument_group('Distributed training')
# fmt: off
group.add_argument('--distributed-world-size', type=int, metavar='N',
default=max(1, torch.cuda.device_count()),
help='total number of GPUs across all nodes (default: all visible GPUs)')
group.add_argument('--distributed-rank', default=0, type=int,
help='rank of the current worker')
group.add_argument('--distributed-backend', default='nccl', type=str,
help='distributed backend')
group.add_argument('--distributed-init-method', default=None, type=str,
help='typically tcp://hostname:port that will be used to '
'establish initial connection')
group.add_argument('--distributed-port', default=-1, type=int,
help='port number (not required if using --distributed-init-method)')
group.add_argument('--device-id', '--local_rank', default=0, type=int,
help='which GPU to use (usually configured automatically)')
group.add_argument('--distributed-no-spawn', action='store_true',
help='do not spawn multiple processes even if multiple GPUs are visible')
group.add_argument('--ddp-backend', default='c10d', type=str,
choices=['c10d', 'apex', 'no_c10d'],
help='DistributedDataParallel backend')
group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB',
help='bucket size for reduction')
group.add_argument('--fix-batches-to-gpus', action='store_true',
help='don\'t shuffle batches between GPUs; this reduces overall '
'randomness and may affect precision but avoids the cost of '
're-reading the data')
group.add_argument('--find-unused-parameters', default=False, action='store_true',
                   help='enable unused parameter detection in DistributedDataParallel '
                        '(not applicable to the no_c10d ddp-backend)')
group.add_argument('--fast-stat-sync', default=False, action='store_true',
help='enable fast sync of stats between nodes; this is hardcoded to '
     'sync only some default stats from logging_output.')
group.add_argument('--broadcast-buffers', default=False, action='store_true',
help="Copy non-trainable parameters between GPUs, such as "
"batchnorm population statistics")
group.add_argument('--nprocs-per-node', default=max(1, torch.cuda.device_count()), type=int,
help="number of GPUs in each node. An allreduce operation across GPUs in "
"a node is very fast. Hence, we do allreduce across GPUs in a node, "
"and gossip across different nodes")
# fmt: on
return group
def add_optimization_args(parser):
group = parser.add_argument_group('Optimization')
# fmt: off
group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
help='force stop training at specified epoch')
group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
help='force stop training at specified update')
group.add_argument('--stop-time-hours', default=0, type=float,
help="force stop training after specified cumulative time (if >0)")
group.add_argument('--clip-norm', default=0, type=float, metavar='NORM',
help='clip threshold of gradients')
group.add_argument('--per-sample-clip-norm', default=0, type=float, metavar='PNORM',
help='per-sample clip threshold of gradients, applied before gradient sync across workers; in fp16/bf16 mode, --allreduce-fp32-grad must be set and --ddp-backend must be no_c10d')
group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K',
type=lambda uf: eval_str_list(uf, type=int),
help='update parameters every N_i batches, when in epoch i')
group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list,
metavar='LR_1,LR_2,...,LR_N',
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
group.add_argument('--stop-min-lr', default=-1, type=float, metavar='LR',
help='stop training when the learning rate reaches this minimum')
# fmt: on
return group
def add_checkpoint_args(parser):
group = parser.add_argument_group('Checkpointing')
# fmt: off
group.add_argument('--save-dir', metavar='DIR', default='checkpoints',
help='path to save checkpoints')
group.add_argument('--tmp-save-dir', metavar='DIR', default='./',
help='path to temporarily save checkpoints')
group.add_argument('--restore-file', default='checkpoint_last.pt',
help='filename from which to load checkpoint '
'(default: <save-dir>/checkpoint_last.pt)')
group.add_argument('--finetune-from-model', type=str,
help="finetune from a pretrained model; note that meters and lr scheduler will be reset")
group.add_argument('--load-from-ema', action="store_true",
                   help="initialize model weights from the EMA state stored in the checkpoint")
group.add_argument('--reset-dataloader', action='store_true',
help='if set, does not reload dataloader state from the checkpoint')
group.add_argument('--reset-lr-scheduler', action='store_true',
help='if set, does not load lr scheduler state from the checkpoint')
group.add_argument('--reset-meters', action='store_true',
help='if set, does not load meters from the checkpoint')
group.add_argument('--reset-optimizer', action='store_true',
help='if set, does not load optimizer state from the checkpoint')
group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override optimizer args when loading a checkpoint')
group.add_argument('--save-interval', type=int, default=1, metavar='N',
help='save a checkpoint every N epochs')
group.add_argument('--save-interval-updates', type=int, default=0, metavar='N',
help='save a checkpoint (and validate) every N updates')
group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N',
help='keep the last N checkpoints saved with --save-interval-updates')
group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N',
help='keep last N epoch checkpoints')
group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N',
help='keep best N checkpoints based on scores')
group.add_argument('--no-save', action='store_true',
help='don\'t save models or checkpoints')
group.add_argument('--no-epoch-checkpoints', action='store_true',
help='only store last and best checkpoints')
group.add_argument('--no-last-checkpoints', action='store_true',
help='don\'t store last checkpoints')
group.add_argument('--no-save-optimizer-state', action='store_true',
help='don\'t save optimizer-state as part of checkpoint')
group.add_argument('--best-checkpoint-metric', type=str, default='loss',
help='metric to use for saving "best" checkpoints')
group.add_argument('--maximize-best-checkpoint-metric', action='store_true',
help='select the largest metric value for saving "best" checkpoints')
group.add_argument('--patience', type=int, default=-1, metavar='N',
help="early stop training if valid performance doesn't "
"improve for N consecutive validation runs; note "
"that this is influenced by --validate-interval")
group.add_argument('--checkpoint-suffix', type=str, default="",
help="suffix to add to the checkpoint file name")
# fmt: on
return group
def add_common_eval_args(group):
# fmt: off
group.add_argument('--path', metavar='FILE',
help='path(s) to model file(s), colon separated')
group.add_argument('--quiet', action='store_true',
help='only print final scores')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override model args at generation '
'that were used during model training')
group.add_argument('--results-path', metavar='RESDIR', type=str, default=None,
help='path to save eval results (optional)')
# fmt: on
def add_model_args(parser):
group = parser.add_argument_group('Model configuration')
# fmt: off
# Model definitions can be found under unicore/models/
#
# The model architecture can be specified in several ways.
# In increasing order of priority:
# 1) model defaults (lowest priority)
# 2) --arch argument
# 3) --encoder/decoder-* arguments (highest priority)
from unicore.models import ARCH_MODEL_REGISTRY
group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(),
help='Model Architecture')
# fmt: on
return group
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
REGISTRIES = {}
def setup_registry(
registry_name: str,
base_class=None,
default=None,
):
assert registry_name.startswith('--')
registry_name = registry_name[2:].replace('-', '_')
REGISTRY = {}
REGISTRY_CLASS_NAMES = set()
# maintain a registry of all registries
if registry_name in REGISTRIES:
return # registry already exists
REGISTRIES[registry_name] = {
'registry': REGISTRY,
'default': default,
}
def build_x(args, *extra_args, **extra_kwargs):
choice = getattr(args, registry_name, None)
if choice is None:
return None
cls = REGISTRY[choice]
if hasattr(cls, 'build_' + registry_name):
builder = getattr(cls, 'build_' + registry_name)
else:
builder = cls
set_defaults(args, cls)
return builder(args, *extra_args, **extra_kwargs)
def register_x(name):
def register_x_cls(cls):
if name in REGISTRY:
raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name))
if cls.__name__ in REGISTRY_CLASS_NAMES:
raise ValueError(
'Cannot register {} with duplicate class name ({})'.format(
registry_name, cls.__name__,
)
)
if base_class is not None and not issubclass(cls, base_class):
raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__))
REGISTRY[name] = cls
REGISTRY_CLASS_NAMES.add(cls.__name__)
return cls
return register_x_cls
return build_x, register_x, REGISTRY
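# --- Illustrative usage sketch (not part of the library API) ----------------
# A minimal sketch of how setup_registry is typically wired up. The component
# name '--demo-component' and the classes below are hypothetical and only
# illustrate the returned build/register pair.
def _example_setup_registry():
    build_demo, register_demo, DEMO_REGISTRY = setup_registry(
        '--demo-component', base_class=None, default=None
    )

    @register_demo('simple')
    class SimpleDemo:
        def __init__(self, args):
            self.args = args

    # build_demo(args) looks up DEMO_REGISTRY[args.demo_component] and
    # instantiates it, i.e. SimpleDemo when args.demo_component == 'simple'.
    return DEMO_REGISTRY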
def set_defaults(args, cls):
"""Helper to set default arguments based on *add_args*."""
if not hasattr(cls, 'add_args'):
return
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False)
cls.add_args(parser)
# copied from argparse.py:
defaults = argparse.Namespace()
for action in parser._actions:
if action.dest is not argparse.SUPPRESS:
if not hasattr(defaults, action.dest):
if action.default is not argparse.SUPPRESS:
setattr(defaults, action.dest, action.default)
for key, default_value in vars(defaults).items():
if not hasattr(args, key):
setattr(args, key, default_value)
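# --- Illustrative usage sketch (not part of the library API) ----------------
# Sketch of what set_defaults does: defaults declared by a component's
# add_args are copied onto args without overriding values already present.
# DemoComponent and --demo-width are hypothetical names.
def _example_set_defaults():
    class DemoComponent:
        @staticmethod
        def add_args(parser):
            parser.add_argument('--demo-width', type=int, default=128)

    args = argparse.Namespace()
    set_defaults(args, DemoComponent)
    assert args.demo_width == 128
    return args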
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""
import argparse
import importlib
import os
from .unicore_task import UnicoreTask
# register dataclass
TASK_REGISTRY = {}
TASK_CLASS_NAMES = set()
def setup_task(args, **kwargs):
return TASK_REGISTRY[args.task].setup_task(args, **kwargs)
def register_task(name):
"""
New tasks can be added to unicore with the
:func:`~unicore.tasks.register_task` function decorator.
For example::
@register_task('classification')
class ClassificationTask(UnicoreTask):
(...)
.. note::
All Tasks must implement the :class:`~unicore.tasks.UnicoreTask`
interface.
Args:
name (str): the name of the task
"""
def register_task_cls(cls):
if name in TASK_REGISTRY:
raise ValueError("Cannot register duplicate task ({})".format(name))
if not issubclass(cls, UnicoreTask):
raise ValueError(
"Task ({}: {}) must extend UnicoreTask".format(name, cls.__name__)
)
if cls.__name__ in TASK_CLASS_NAMES:
raise ValueError(
"Cannot register task with duplicate class name ({})".format(
cls.__name__
)
)
TASK_REGISTRY[name] = cls
TASK_CLASS_NAMES.add(cls.__name__)
return cls
return register_task_cls
# automatically import any Python files in the tasks/ directory
tasks_dir = os.path.dirname(__file__)
for file in os.listdir(tasks_dir):
path = os.path.join(tasks_dir, file)
if (
not file.startswith("_")
and not file.startswith(".")
and (file.endswith(".py") or os.path.isdir(path))
):
task_name = file[: file.find(".py")] if file.endswith(".py") else file
module = importlib.import_module("unicore.tasks." + task_name)
# expose `task_parser` for sphinx
if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group("Task name")
# fmt: off
group_task.add_argument('--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``')
# fmt: on
group_args = parser.add_argument_group("Additional command-line arguments")
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + "_parser"] = parser
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Callable, Dict, List
import torch
from unicore import metrics, utils
from unicore.data import UnicoreDataset, data_utils, iterators
logger = logging.getLogger(__name__)
class StatefulContainer(object):
    def __init__(self):
        # use per-instance containers so that separate tasks do not share state
        self._state: Dict[str, Any] = dict()
        self._factories: Dict[str, Callable[[], Any]] = dict()
def add_factory(self, name, factory: Callable[[], Any]):
self._factories[name] = factory
def merge_state_dict(self, state_dict: Dict[str, Any]):
self._state.update(state_dict)
@property
def state_dict(self) -> Dict[str, Any]:
return self._state
def __getattr__(self, name):
if name not in self._state and name in self._factories:
self._state[name] = self._factories[name]()
if name in self._state:
return self._state[name]
raise AttributeError(f"Task state has no factory for attribute {name}")
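# --- Illustrative usage sketch (not used by the library) --------------------
# A factory registered on a StatefulContainer is only invoked on first access
# and its result is cached in the container's state; the "dictionary" key and
# the toy vocabulary below are hypothetical.
def _example_stateful_container():
    state = StatefulContainer()
    state.add_factory("dictionary", lambda: {"<pad>": 0, "<unk>": 1})
    vocab = state.dictionary  # invokes the factory once and caches the result
    assert state.state_dict == {"dictionary": vocab}
    return vocab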
class UnicoreTask(object):
"""
Tasks store dictionaries and provide helpers for loading/iterating over
Datasets, initializing the Model/Loss and calculating the loss.
Tasks have limited statefulness. In particular, state that needs to be
saved to/loaded from checkpoints needs to be stored in the `self.state`
:class:`StatefulContainer` object. For example::
self.state.add_factory("dictionary", self.load_dictionary)
print(self.state.dictionary) # calls self.load_dictionary()
This is necessary so that when loading checkpoints, we can properly
recreate the task state after initializing the task instance.
"""
@classmethod
def add_args(cls, parser):
"""Add task-specific arguments to the parser."""
pass
@staticmethod
def logging_outputs_can_be_summed(loss, is_train) -> bool:
"""
Whether the logging outputs returned by `train_step` and `valid_step` can
be summed across workers prior to calling `reduce_metrics`.
Setting this to True improves distributed training speed.
"""
return loss.logging_outputs_can_be_summed(is_train)
args: Namespace
datasets: Dict[str, UnicoreDataset]
dataset_to_epoch_iter: Dict[UnicoreDataset, Any]
state: StatefulContainer = None
def __init__(self, args: Namespace, **kwargs):
self.args = args
self.datasets = dict()
self.dataset_to_epoch_iter = dict()
self.state = StatefulContainer()
@classmethod
def setup_task(cls, args: Namespace, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (Namespace): parsed command-line arguments
"""
return cls(args, **kwargs)
def has_sharded_data(self, split):
return os.pathsep in getattr(self.args, "data", "")
def load_dataset(
self,
split: str,
combine: bool = False,
**kwargs
):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
combine (bool): combines a split segmented into pieces into one dataset
"""
raise NotImplementedError
def dataset(self, split):
"""
Return a loaded dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
Returns:
a :class:`~unicore.data.UnicoreDataset` corresponding to *split*
"""
from unicore.data import UnicoreDataset
if split not in self.datasets:
raise KeyError("Dataset not loaded: " + split)
if not isinstance(self.datasets[split], UnicoreDataset):
raise TypeError("Datasets are expected to be of type UnicoreDataset")
return self.datasets[split]
def can_reuse_epoch_itr(self, dataset):
# We can reuse the epoch iterator across epochs as long as the dataset
# hasn't disabled it. We default to ``False`` here, although in practice
# this will be ``True`` for most datasets that inherit from
# ``UnicoreDataset`` due to the base implementation there.
return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False)
def get_batch_iterator(
self,
dataset,
batch_size=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
):
"""
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~unicore.data.UnicoreDataset): dataset to batch
batch_size (int, optional): max number of samples in each
batch (default: None).
ignore_invalid_inputs (bool, optional): don't raise Exception for
sentences that are too long (default: False).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N (default: 1).
seed (int, optional): seed for random number generator for
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 1).
data_buffer_size (int, optional): number of batches to
preload (default: 0).
disable_iterator_cache (bool, optional): don't cache the
EpochBatchIterator (ignores `UnicoreTask::can_reuse_epoch_itr`)
(default: False).
Returns:
~unicore.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr(
dataset
)
if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
logger.info("reusing EpochBatchIterator for epoch {}".format(epoch))
return self.dataset_to_epoch_iter[dataset]
else:
logger.info("get EpochBatchIterator for epoch {}".format(epoch))
assert isinstance(dataset, UnicoreDataset)
# initialize the dataset with the correct starting epoch
dataset.set_epoch(epoch)
# get indices ordered by example size
with data_utils.numpy_seed(seed):
indices = dataset.ordered_indices()
# create mini-batches with given size constraints
batch_sampler = dataset.batch_by_size(
indices,
batch_size=batch_size,
required_batch_size_multiple=required_batch_size_multiple,
)
# return a reusable, sharded iterator
epoch_iter = iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
buffer_size=data_buffer_size,
disable_shuffling=self.disable_shuffling(),
)
if can_reuse_epoch_itr:
self.dataset_to_epoch_iter[dataset] = epoch_iter
return epoch_iter
def build_model(self, args: Namespace):
"""
Build the :class:`~unicore.models.BaseUnicoreModel` instance for this
task.
Returns:
a :class:`~unicore.models.BaseUnicoreModel` instance
"""
from unicore import models
return models.build_model(args, self)
def build_loss(self, args: Namespace):
"""
Build the :class:`~unicore.losses.UnicoreLoss` instance for
this task.
Args:
args (Namespace): configuration object
Returns:
a :class:`~unicore.losses.UnicoreLoss` instance
"""
from unicore import losses
return losses.build_loss(args, self)
def train_step(
self, sample, model, loss, optimizer, update_num, ignore_grad=False
):
"""
Do forward and backward, and return the loss as computed by *loss*
for the given *model* and *sample*.
Args:
sample (dict): the mini-batch. The format is defined by the
:class:`~unicore.data.UnicoreDataset`.
model (~unicore.models.BaseUnicoreModel): the model
loss (~unicore.losses.UnicoreLoss): the loss
optimizer (~unicore.optim.UnicoreOptimizer): the optimizer
update_num (int): the current update
ignore_grad (bool): multiply loss by 0 if this is set to True
Returns:
tuple:
- the loss
- the sample size, which is used as the denominator for the
gradient
- logging outputs to display while training
"""
model.train()
model.set_num_updates(update_num)
with torch.autograd.profiler.record_function("forward"):
loss, sample_size, logging_output = loss(model, sample)
if ignore_grad:
loss *= 0
with torch.autograd.profiler.record_function("backward"):
optimizer.backward(loss)
return loss, sample_size, logging_output
def valid_step(self, sample, model, loss, test=False):
model.eval()
with torch.no_grad():
loss, sample_size, logging_output = loss(model, sample)
return loss, sample_size, logging_output
def optimizer_step(self, optimizer, model, update_num):
optimizer.step()
def build_dataset_for_inference(
self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
) -> torch.utils.data.Dataset:
raise NotImplementedError
def begin_epoch(self, epoch, model):
"""Hook function called before the start of each epoch."""
pass
def begin_valid_epoch(self, epoch, model):
"""Hook function called before the start of each validation epoch."""
pass
def reduce_metrics(self, logging_outputs, loss, split='train'):
"""Aggregate logging outputs from data parallel training."""
if not any("bsz" in log for log in logging_outputs):
warnings.warn(
"bsz not found in Loss logging outputs, cannot log bsz"
)
else:
bsz = sum(log.get("bsz", 0) for log in logging_outputs)
metrics.log_scalar("bsz", bsz, priority=190, round=1)
loss.__class__.reduce_metrics(logging_outputs, split)
def state_dict(self):
if self.state is not None:
return self.state.state_dict
return {}
def load_state_dict(self, state_dict: Dict[str, Any]):
if self.state is not None:
self.state.merge_state_dict(state_dict)
def disable_shuffling(self) -> bool:
return False
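# --- Illustrative usage sketch (not used by the library) --------------------
# A minimal, hypothetical task showing the pieces a concrete subclass usually
# provides; real tasks live under unicore/tasks/ and register themselves with
# @register_task at import time rather than inside a function.
def _example_minimal_task():
    from unicore.tasks import register_task

    @register_task("demo_task")
    class DemoTask(UnicoreTask):
        @classmethod
        def add_args(cls, parser):
            parser.add_argument("--data", default="", help="path to the data")

        def load_dataset(self, split, combine=False, **kwargs):
            # a real implementation would construct a UnicoreDataset for
            # *split* and store it in self.datasets[split]
            raise NotImplementedError

    return DemoTask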
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a network across multiple GPUs.
"""
import contextlib
import logging
import os
import sys
import time
from itertools import chain
from typing import Any, Dict, List
import torch
from unicore import checkpoint_utils, models, optim, utils
from unicore.distributed import utils as distributed_utils
from unicore.logging import meters, metrics
from unicore.nan_detector import NanDetector
from unicore.optim import lr_scheduler
from unicore.utils import tensor_tree_map
logger = logging.getLogger(__name__)
class ExponentialMovingAverage:
"""
Maintains moving averages of parameters with exponential decay.
At each step, the stored copy `copy` of each parameter `param` is
updated as follows:
`copy = decay * copy + (1 - decay) * param`
where `decay` is an attribute of the ExponentialMovingAverage object.
"""
def __init__(self, model: torch.nn.Module, decay: float):
"""
Args:
model:
A torch.nn.Module whose parameters are to be tracked
decay:
A value (usually close to 1.) by which updates are
weighted as part of the above formula
"""
super(ExponentialMovingAverage, self).__init__()
with torch.no_grad():
clone_param = lambda t: t.clone().detach().float()
self.params = tensor_tree_map(clone_param, model.state_dict())
self.decay = decay
def _update_state_dict_(self, update, state_dict):
with torch.no_grad():
for k, v in update.items():
if state_dict[k].device != v.device:
state_dict[k] = state_dict[k].to(v.device)
stored = state_dict[k]
if not isinstance(v, torch.Tensor):
self._update_state_dict_(v, stored)
else:
diff = stored - v.float()
diff *= 1 - self.decay
stored -= diff
def update(self, model: torch.nn.Module) -> None:
"""
Updates the stored parameters using the state dict of the provided
module. The module should have the same structure as that used to
initialize the ExponentialMovingAverage object.
"""
self._update_state_dict_(model.state_dict(), self.params)
def load_state_dict(self, state_dict: dict) -> None:
self.params = state_dict["params"]
self.decay = state_dict["decay"] if "decay" in state_dict else self.decay
def state_dict(self) -> dict:
return {
"params": self.params,
"decay": self.decay,
}
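# --- Illustrative usage sketch (not used by the library) --------------------
# Typical use of the moving average above: update it once per optimizer step;
# ema.params then tracks copy = decay * copy + (1 - decay) * param. The helper
# name and the decay value are illustrative.
def _example_ema_usage(model: torch.nn.Module):
    ema = ExponentialMovingAverage(model, decay=0.999)
    # ... inside the training loop, after each optimizer step:
    ema.update(model)
    # the EMA weights can be checkpointed alongside the model
    return ema.state_dict()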
class Trainer(object):
"""Main class for data parallel training.
This class supports synchronous distributed data parallel training,
where multiple workers each have a full model replica and gradients
are accumulated across workers before each update. We use
:class:`~torch.nn.parallel.DistributedDataParallel` to handle
communication of the gradients across workers.
"""
def __init__(self, args, task, model, loss):
self.args = args
self.task = task
# catalog shared parameters
shared_params = _catalog_shared_params(model)
self.cuda = torch.cuda.is_available()
if self.cuda:
self.device = torch.device("cuda")
else:
self.device = torch.device("cpu")
# copy model and loss to current device/dtype
self._loss = loss
self._model = model
if args.fp16:
self._loss = self._loss.half()
self._model = self._model.half()
elif args.bf16:
self._loss = self._loss.bfloat16()
self._model = self._model.bfloat16()
if (
# the DistributedUnicoreModel wrapper will handle moving to device,
# so only handle cases which don't use the wrapper
not self.use_distributed_wrapper
):
self._loss = self._loss.to(device=self.device)
self._model = self._model.to(device=self.device)
# check that shared parameters are preserved after device transfer
for shared_param in shared_params:
ref = _get_module_by_path(self._model, shared_param[0])
for path in shared_param[1:]:
logger.info(
"detected shared parameter: {} <- {}".format(shared_param[0], path)
)
_set_module_by_path(self._model, path, ref)
self._dummy_batch = None # indicates we don't have a dummy batch at first
self._total_train_steps = None
self._lr_scheduler = None
self._num_updates = 0
self._optim_history = None
self._optimizer = None
self._warn_once = set()
self._wrapped_loss = None
self._wrapped_model = None
if self.cuda and self.data_parallel_world_size > 1:
self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size)
else:
self._grad_norm_buf = None
# get detailed cuda environment
if self.cuda:
self.cuda_env = utils.CudaEnvironment()
if self.data_parallel_world_size > 1:
self.cuda_env_arr = distributed_utils.all_gather_list(
self.cuda_env, group=distributed_utils.get_global_group()
)
else:
self.cuda_env_arr = [self.cuda_env]
if self.data_parallel_rank == 0:
utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)
else:
self.cuda_env = None
self.cuda_env_arr = None
# add ema
if args.ema_decay > 0 and self.data_parallel_rank == 0:
self.ema = ExponentialMovingAverage(self._model, decay=args.ema_decay)
else:
self.ema = None
metrics.log_start_time("wall", priority=790, round=2)
self._start_time = time.time()
self._previous_training_time = 0
self._cumulative_training_time = None
def reinitialize(self):
"""Reinitialize the Trainer, typically after model params change."""
self._lr_scheduler = None
self._optimizer = None
self._wrapped_loss = None
self._wrapped_model = None
@property
def data_parallel_world_size(self):
if self.args.distributed_world_size == 1:
return 1
return distributed_utils.get_data_parallel_world_size()
@property
def data_parallel_process_group(self):
return distributed_utils.get_data_parallel_group()
@property
def data_parallel_rank(self):
if self.args.distributed_world_size == 1:
return 0
return distributed_utils.get_data_parallel_rank()
@property
def is_data_parallel_master(self):
# NOTE: this returns true for all model parallel replicas with data
# parallel rank 0
return self.data_parallel_rank == 0
@property
def use_distributed_wrapper(self) -> bool:
return self.data_parallel_world_size > 1
@property
def should_save_checkpoint_on_current_rank(self) -> bool:
"""Indicates whether to save checkpoints on the current DDP rank."""
return self.is_data_parallel_master
@property
def checkpoint_suffix(self) -> str:
"""Suffix to add to the checkpoint file name."""
return self.args.checkpoint_suffix or ""
@property
def loss(self):
if self._wrapped_loss is None:
if utils.has_parameters(self._loss) and self.use_distributed_wrapper:
self._wrapped_loss = models.DistributedUnicoreModel(
self.args,
self._loss,
process_group=self.data_parallel_process_group,
device=self.device,
)
else:
self._wrapped_loss = self._loss
return self._wrapped_loss
@property
def model(self):
if self._wrapped_model is None:
if self.use_distributed_wrapper:
self._wrapped_model = models.DistributedUnicoreModel(
self.args,
self._model,
process_group=self.data_parallel_process_group,
device=self.device,
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property
def optimizer(self):
if self._optimizer is None:
self._build_optimizer()
return self._optimizer
@property
def lr_scheduler(self):
if self._lr_scheduler is None:
self._build_optimizer() # this will initialize self._lr_scheduler
return self._lr_scheduler
def _build_optimizer(self):
params = list(
filter(
lambda p: p.requires_grad,
chain(self.model.parameters(), self.loss.parameters()),
)
)
if self.args.per_sample_clip_norm > 0:
assert self.args.ddp_backend == "no_c10d"
assert self.args.batch_size == 1
if self.args.fp16 or self.args.bf16:
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7:
logger.info(
"NOTE: your device does NOT support faster training with --fp16, "
"please switch to FP32 which is likely to be faster"
)
self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params)
if self.args.allreduce_fp32_grad:
assert self.args.ddp_backend == "no_c10d"
if self.args.per_sample_clip_norm > 0:
assert self.args.allreduce_fp32_grad
else:
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7:
logger.info("NOTE: your device may support faster training with --fp16")
self._optimizer = optim.build_optimizer(self.args, params)
# We should initialize the learning rate scheduler immediately after
# building the optimizer, so that the initial learning rate is set.
self._lr_scheduler = lr_scheduler.build_lr_scheduler(
self.args,
self.optimizer,
self._total_train_steps,
)
self._lr_scheduler.step_update(0)
def state_dict(self):
state_dict = {
"args": self.args,
"model": self.model.state_dict(),
"loss": (
self.loss.state_dict() if utils.has_parameters(self.loss) else None
),
"optimizer_history": (self._optim_history or [])
+ [
{
"loss_name": self.get_loss().__class__.__name__,
"optimizer_name": self.optimizer.__class__.__name__,
"lr_scheduler_state": self.lr_scheduler.state_dict(),
"num_updates": self.get_num_updates(),
}
],
"task_state": self.task.state_dict() if self.task is not None else {},
"extra_state": {
"metrics": metrics.state_dict(),
"previous_training_time": self.cumulative_training_time(),
},
}
if not self.args.no_save_optimizer_state:
state_dict["last_optimizer_state"] = self.optimizer.state_dict()
if self.ema is not None:
state_dict["ema"] = self.ema.state_dict()
return state_dict
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
logger.info(f"Saving checkpoint to {filename}")
# call state_dict on all ranks in case it needs internal communication
state_dict = utils.move_to_cpu(self.state_dict())
state_dict["extra_state"].update(extra_state)
if self.should_save_checkpoint_on_current_rank:
checkpoint_utils.torch_persistent_save(
state_dict,
filename,
)
logger.info(f"Finished saving checkpoint to {filename}")
def load_checkpoint(
self,
filename,
reset_optimizer=False,
reset_lr_scheduler=False,
optimizer_overrides=None,
reset_meters=False,
):
"""
Load all training state from a checkpoint file.
rank = 0 will load the checkpoint, and then broadcast it to all
other ranks.
"""
extra_state, self._optim_history, last_optim_state = None, [], None
logger.info(f"Preparing to load checkpoint {filename}")
is_distributed = self.data_parallel_world_size > 1
is_master = self.data_parallel_rank == 0
bexists = None
if is_master:
bexists = os.path.isfile(filename)
if is_distributed:
bexists = distributed_utils.broadcast_object(
bexists,
src_rank=0,
group=self.data_parallel_process_group,
dist_device=self.device,
)
had_loaded_model = False
if bexists:
state = None
if is_master:
state = checkpoint_utils.load_checkpoint_to_cpu(
filename,
)
if is_distributed:
logger.info("Broadcast checkpoint from rank_0")
state = distributed_utils.broadcast_object(
state,
src_rank=0,
group=self.data_parallel_process_group,
dist_device=self.device,
)
last_optim_state = state.get("last_optimizer_state", None)
ema_state = state.get("ema", None)
# load model parameters
try:
if self.args.load_from_ema:
logger.info("loading ema state to model")
errors = self.model.load_state_dict(
ema_state["params"], strict=False, model_args=self.args
)
else:
errors = self.model.load_state_dict(
state["model"], strict=False, model_args=self.args
)
# save memory for later steps
del state["model"]
had_loaded_model = True
if errors.missing_keys:
logger.warning(
"Error in loading model state, missing_keys "
+ str(errors.missing_keys)
)
if errors.unexpected_keys:
logger.warning(
"Error in loading model state, unexpected_keys "
+ str(errors.unexpected_keys)
)
if utils.has_parameters(self.get_loss()):
self.get_loss().load_state_dict(state["loss"], strict=True)
del state["loss"]
except Exception:
raise Exception(
"Cannot load model parameters from checkpoint {}; "
"please ensure that the architectures match.".format(filename)
)
extra_state = state["extra_state"] if "extra_state" in state else None
self._optim_history = (
state["optimizer_history"] if "optimizer_history" in state else None
)
if (
ema_state is not None
and self.ema is not None
and not self.args.load_from_ema
):
logger.info(f"Loading EMA state...")
self.ema.load_state_dict(ema_state)
elif self.ema is not None:
logger.info(
    "EMA state not found in checkpoint; initializing EMA from current model weights"
)
self.ema = ExponentialMovingAverage(self._model, decay=self.ema.decay)
if last_optim_state is not None and not reset_optimizer:
# rebuild optimizer after loading model, since params may have changed
self._build_optimizer()
# only reload optimizer and lr_scheduler if they match
last_optim = self._optim_history[-1]
assert (
last_optim["loss_name"] == self.get_loss().__class__.__name__
), f"Loss does not match; please reset the optimizer (--reset-optimizer). {last_optim['loss_name']} vs {self.get_loss().__class__.__name__}"
assert (
last_optim["optimizer_name"] == self.optimizer.__class__.__name__
), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}"
if not reset_lr_scheduler:
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"])
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides)
self.set_num_updates(last_optim["num_updates"])
if extra_state is not None:
itr_state = extra_state["train_iterator"]
epoch = itr_state["epoch"]
if "previous_training_time" in extra_state:
self._previous_training_time = extra_state["previous_training_time"]
self._start_time = time.time()
# self.lr_step(epoch)
if (
itr_state.get("version", 1) >= 2
and itr_state["iterations_in_epoch"] == 0
):
# reset meters at start of epoch
reset_meters = True
if "metrics" in extra_state and not reset_meters:
metrics.load_state_dict(extra_state["metrics"])
# reset TimeMeters, since their start times don't make sense anymore
for meter in metrics.get_meters("default"):
if isinstance(meter, meters.TimeMeter):
meter.reset()
logger.info(
"Loaded checkpoint {} (epoch {} @ {} updates)".format(
filename, epoch, self.get_num_updates()
)
)
elif had_loaded_model:
logger.info("Loaded checkpoint {}".format(filename))
else:
logger.info("No existing checkpoint found {}".format(filename))
return extra_state
def get_train_iterator(
self,
epoch,
combine=True,
load_dataset=True,
data_selector=None,
shard_batch_itr=True,
disable_iterator_cache=False,
):
"""Return an EpochBatchIterator over the training set for a given epoch."""
if load_dataset:
logger.info("loading train data for epoch {}".format(epoch))
self.task.load_dataset(
self.args.train_subset,
epoch=epoch,
combine=combine,
data_selector=data_selector,
)
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(self.args.train_subset),
batch_size=self.args.batch_size,
ignore_invalid_inputs=True,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers,
epoch=epoch,
data_buffer_size=self.args.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
self.reset_dummy_batch(batch_iterator.first_batch)
return batch_iterator
def init_total_train_steps(self, epoch_itr):
if self.args.max_epoch > 0:
self._total_train_steps = (
(len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch
)
else:
self._total_train_steps = self.args.max_update
def get_valid_iterator(
self,
subset,
disable_iterator_cache=False,
):
"""Return an EpochBatchIterator over given validation subset for a given epoch."""
batch_iterator = self.task.get_batch_iterator(
dataset=self.task.dataset(subset),
batch_size=self.args.batch_size_valid,
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test,
required_batch_size_multiple=self.args.required_batch_size_multiple,
seed=self.args.seed,
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.args.num_workers,
# always pass a fixed "epoch" to keep validation data consistent
# across training epochs
epoch=1,
data_buffer_size=self.args.data_buffer_size,
disable_iterator_cache=disable_iterator_cache,
)
self.reset_dummy_batch(batch_iterator.first_batch)
return batch_iterator
def begin_epoch(self, epoch):
"""Called at the beginning of each epoch."""
logger.info("begin training epoch {}".format(epoch))
self.lr_step_begin_epoch(epoch)
# task specific setup per epoch
self.task.begin_epoch(epoch, self.get_model())
def begin_valid_epoch(self, epoch):
"""Called at the beginning of each validation epoch."""
# task specific setup per validation epoch
self.task.begin_valid_epoch(epoch, self.get_model())
def reset_dummy_batch(self, batch):
self._dummy_batch = batch
@metrics.aggregate("train")
def train_step(self, samples, raise_oom=False):
"""Do forward, backward and parameter update."""
self.model.train()
self.loss.train()
self.zero_grad()
metrics.log_start_time("train_wall", priority=800, round=2)
# forward and backward pass
logging_outputs, sample_size, ooms = [], 0, 0
for i, sample in enumerate(samples): # delayed update loop
sample, is_dummy_batch = self._prepare_sample(sample)
def maybe_no_sync():
"""
Whenever *samples* contains more than one mini-batch, we
want to accumulate gradients locally and only call
all-reduce in the last backwards pass.
"""
if (
self.data_parallel_world_size > 1
and hasattr(self.model, "no_sync")
and i < len(samples) - 1
):
return self.model.no_sync()
else:
return contextlib.ExitStack() # dummy contextmanager
try:
with maybe_no_sync():
# use a different seed for each rank during training; otherwise dropout would be identical across workers
with utils.torch_seed(
self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
):
# forward and backward
loss, sample_size_i, logging_output = self.task.train_step(
sample=sample,
model=self.model,
loss=self.loss,
optimizer=self.optimizer,
update_num=self.get_num_updates(),
ignore_grad=is_dummy_batch,
)
del loss
if self.args.per_sample_clip_norm > 0:
self.optimizer.per_sample_clip_grad_norm(
self.args.per_sample_clip_norm
)
logging_outputs.append(logging_output)
sample_size += sample_size_i
# emptying the CUDA cache after the first step can
# reduce the chance of OOM
if self.cuda and self.get_num_updates() == 0:
torch.cuda.empty_cache()
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
if raise_oom:
raise e
logger.warning(
"attempting to recover from OOM in forward/backward pass"
)
ooms += 1
self.zero_grad()
if self.cuda:
torch.cuda.empty_cache()
if self.args.distributed_world_size == 1:
return None
else:
raise e
if is_dummy_batch:
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
sample_size *= 0.0
if torch.is_tensor(sample_size):
sample_size = sample_size.float()
else:
sample_size = float(sample_size)
local_sample_size = sample_size
# gather logging outputs from all replicas
if self._sync_stats():
train_time = self._local_cumulative_training_time()
logging_outputs, (
sample_size,
ooms,
total_train_time,
) = self._aggregate_logging_outputs(
logging_outputs,
sample_size,
ooms,
train_time,
ignore=is_dummy_batch,
is_train=True,
)
self._cumulative_training_time = (
total_train_time / self.data_parallel_world_size
)
overflow = False
try:
with torch.autograd.profiler.record_function("reduce-grads"):
# reduce gradients across workers
self.optimizer.all_reduce_grads(self.model)
if utils.has_parameters(self.loss):
self.optimizer.all_reduce_grads(self.loss)
with torch.autograd.profiler.record_function("multiply-grads"):
# multiply gradients by (data_parallel_size / sample_size) since
# DDP normalizes by the number of data parallel workers for
# improved fp16 precision.
# Thus we get (sum_of_gradients / sample_size) at the end.
# In case of fp16, this step also undoes loss scaling.
# (Debugging note: Some optimizers perform this scaling on the
# fly, so inspecting model.parameters() or optimizer.params may
# still show the original, unscaled gradients.)
numer = self.data_parallel_world_size if self._sync_stats() else 1
self.optimizer.multiply_grads(numer / (sample_size or 1.0))
# Note: (sample_size or 1.0) handles the case of a zero gradient, in a
# way that avoids CPU/device transfers in case sample_size is a GPU or
# TPU object. The assumption is that the gradient itself is also 0.
with torch.autograd.profiler.record_function("clip-grads"):
# clip grads
grad_norm = self.clip_grad_norm(self.args.clip_norm)
self._check_grad_norms(grad_norm)
if not torch.isfinite(grad_norm).all():
# check local gradnorm single GPU case, trigger NanDetector
raise FloatingPointError("gradients are Nan/Inf")
with torch.autograd.profiler.record_function("optimizer"):
# fix the seed so that stochastic rounding behaves identically across ranks
with utils.torch_seed(self.args.seed, self.get_num_updates()):
# take an optimization step
self.task.optimizer_step(
self.optimizer,
model=self.model,
update_num=self.get_num_updates(),
)
if self.ema is not None:
with torch.autograd.profiler.record_function("ema"):
self.ema.update(self.model)
except FloatingPointError:
# re-run the forward and backward pass with hooks attached to print
# out where it fails
self.zero_grad()
with NanDetector(self.get_model()):
for i, sample in enumerate(samples):
sample, _ = self._prepare_sample(sample)
with utils.torch_seed(
self.args.seed,
self.get_num_updates(),
i,
self.data_parallel_rank,
):
self.task.train_step(
sample,
self.model,
self.loss,
self.optimizer,
self.get_num_updates(),
ignore_grad=False,
)
raise
except OverflowError as e:
overflow = True
logger.info(
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
)
grad_norm = torch.tensor(0.0).cuda()
self.zero_grad()
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
logger.error("OOM during optimization, irrecoverable")
raise e
logging_output = None
if not overflow:
self.set_num_updates(self.get_num_updates() + 1)
if self.cuda and self.cuda_env is not None:
# log minimum free memory over the iteration
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
torch.cuda.reset_peak_memory_stats()
gb_free = self.cuda_env.total_memory_in_GB - gb_used
metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0)
# log stats
logging_output = self._reduce_and_log_stats(
logging_outputs,
sample_size,
grad_norm,
)
# clear CUDA cache to reduce memory fragmentation
if (
self.cuda
and self.args.empty_cache_freq > 0
and (
(self.get_num_updates() + self.args.empty_cache_freq - 1)
% self.args.empty_cache_freq
)
== 0
):
torch.cuda.empty_cache()
if self.args.fp16:
metrics.log_scalar(
"loss_scale",
self.optimizer.scaler.loss_scale,
priority=700,
round=4,
weight=0,
)
metrics.log_stop_time("train_wall")
return logging_output
@metrics.aggregate("valid")
def valid_step(self, sample, raise_oom=False):
"""Do forward pass in evaluation mode."""
with torch.no_grad():
self.model.eval()
self.loss.eval()
sample, is_dummy_batch = self._prepare_sample(sample)
try:
_loss, sample_size, logging_output = self.task.valid_step(
sample, self.model, self.loss
)
except RuntimeError as e:
if "out of memory" in str(e):
self._log_oom(e)
if not raise_oom:
logger.warning(
"ran out of memory in validation step, retrying batch"
)
for p in self.model.parameters():
if p.grad is not None:
p.grad = None # free some memory
if self.cuda:
torch.cuda.empty_cache()
return self.valid_step(sample, raise_oom=True)
raise e
logging_outputs = [logging_output]
if is_dummy_batch:
if torch.is_tensor(sample_size):
sample_size.zero_()
else:
sample_size *= 0.0
# gather logging outputs from all replicas
if self.data_parallel_world_size > 1:
logging_outputs, (sample_size,) = self._aggregate_logging_outputs(
logging_outputs,
sample_size,
ignore=is_dummy_batch,
is_train=False,
)
return logging_outputs
def zero_grad(self):
self.optimizer.zero_grad()
def lr_step_begin_epoch(self, epoch):
"""Adjust the learning rate at the beginning of the epoch."""
self.lr_scheduler.step_begin_epoch(epoch)
# prefer updating the LR based on the number of steps
return self.lr_step_update()
def lr_step(self, epoch, val_loss=None):
"""Adjust the learning rate at the end of the epoch."""
self.lr_scheduler.step(epoch, val_loss)
# prefer updating the LR based on the number of steps
return self.lr_step_update()
def lr_step_update(self):
"""Update the learning rate after each update."""
new_lr = self.lr_scheduler.step_update(self.get_num_updates())
if isinstance(new_lr, dict):
for k, v in new_lr.items():
metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300)
new_lr = new_lr.get("default", next(iter(new_lr.values())))
else:
metrics.log_scalar("lr", new_lr, weight=0, priority=300)
return new_lr
def get_lr(self):
"""Get the current learning rate."""
return self.optimizer.get_lr()
def get_model(self):
"""Get the (non-wrapped) model instance."""
return self._model
def get_loss(self):
"""Get the (non-wrapped) loss instance."""
return self._loss
def get_num_updates(self):
"""Get the number of parameters updates."""
return self._num_updates
def set_num_updates(self, num_updates):
"""Set the number of parameters updates."""
self._num_updates = num_updates
self.lr_step_update()
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200)
def clip_grad_norm(self, clip_norm):
return self.optimizer.clip_grad_norm(clip_norm)
def cumulative_training_time(self):
if self._cumulative_training_time is None:
# single GPU
return self._local_cumulative_training_time()
else:
return self._cumulative_training_time
def _local_cumulative_training_time(self):
"""Aggregate training time in seconds."""
return time.time() - self._start_time + self._previous_training_time
def _prepare_sample(self, sample, is_dummy=False):
if sample == "DUMMY":
raise Exception(
"Trying to use an uninitialized 'dummy' batch. This usually indicates "
"that the total number of batches is smaller than the number of "
"participating GPUs. Try reducing the batch size or using fewer GPUs."
)
if sample is None or len(sample) == 0:
assert (
self._dummy_batch is not None and len(self._dummy_batch) > 0
), "Invalid dummy batch: {}".format(self._dummy_batch)
sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True)
return sample, True
if self.cuda:
sample = utils.move_to_cuda(sample)
def apply_half(t):
if t.dtype is torch.float32:
return t.half()
return t
def apply_bfloat16(t):
if t.dtype is torch.float32:
return t.to(dtype=torch.bfloat16)
return t
# Data types are not converted automatically here; convert samples in the task/dataset if needed, or uncomment the helpers below.
# if self.args.fp16:
# sample = utils.apply_to_sample(apply_half, sample)
# if self.args.bf16:
# sample = utils.apply_to_sample(apply_bfloat16, sample)
if self._dummy_batch == "DUMMY":
self._dummy_batch = sample
return sample, False
def _sync_stats(self):
# Return True when training with more than one data-parallel worker, so that stats need to be synced across workers.
if self.data_parallel_world_size == 1:
return False
else:
return True
def _log_oom(self, exc):
msg = "OOM: Ran out of memory with exception: {}".format(exc)
logger.warning(msg)
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
for device_idx in range(torch.cuda.device_count()):
logger.warning(torch.cuda.memory_summary(device=device_idx))
sys.stderr.flush()
def _aggregate_logging_outputs(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
is_train=False,
):
if self.task.__class__.logging_outputs_can_be_summed(
self.get_loss(), is_train=is_train
):
return self._fast_stat_sync_sum(
logging_outputs, *extra_stats_to_sum, ignore=ignore
)
else:
return self._all_gather_list_sync(
logging_outputs, *extra_stats_to_sum, ignore=ignore
)
def _all_gather_list_sync(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
"""
Sync logging outputs across workers. all_gather_list_sync is
suitable when logging outputs are complex types.
"""
if ignore:
logging_outputs = []
results = list(
zip(
*distributed_utils.all_gather_list(
[logging_outputs] + list(extra_stats_to_sum),
max_size=getattr(self.args, "all_gather_list_size", 16384),
group=self.data_parallel_process_group,
)
)
)
logging_outputs, extra_stats_to_sum = results[0], results[1:]
logging_outputs = list(chain.from_iterable(logging_outputs))
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
return logging_outputs, extra_stats_to_sum
def _fast_stat_sync_sum(
self,
logging_outputs: List[Dict[str, Any]],
*extra_stats_to_sum,
ignore=False,
):
"""
Sync logging outputs across workers. fast_stat_sync_sum is
faster than all_gather_list_sync, but is only suitable when
logging outputs are scalars and can be summed. Note that
*logging_outputs* cannot contain any nested dicts/lists.
"""
data = {}
for i, stat in enumerate(extra_stats_to_sum):
data["extra_stats_" + str(i)] = stat
if len(logging_outputs) > 0:
log_keys = list(logging_outputs[0].keys())
for k in log_keys:
if not ignore:
v = sum(log[k] for log in logging_outputs if k in log)
else:
v = logging_outputs[0][k]
v = torch.zeros_like(v) if torch.is_tensor(v) else 0
data["logging_outputs_" + k] = v
else:
log_keys = None
data = distributed_utils.all_reduce_dict(
data, device=self.device, group=self.data_parallel_process_group
)
extra_stats_to_sum = [
data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum))
]
if log_keys is not None:
logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}]
else:
logging_outputs = []
return logging_outputs, extra_stats_to_sum
def _check_grad_norms(self, grad_norm):
"""Check that grad norms are consistent across workers."""
if self._grad_norm_buf is not None:
self._grad_norm_buf.zero_()
self._grad_norm_buf[self.data_parallel_rank] = grad_norm
distributed_utils.all_reduce(
self._grad_norm_buf, group=self.data_parallel_process_group
)
def is_consistent(tensor):
max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
return (
torch.isfinite(tensor).all()
and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()
)
if not is_consistent(self._grad_norm_buf):
pretty_detail = "\n".join(
"rank {:3d} = {:.8f}".format(r, n)
for r, n in enumerate(self._grad_norm_buf.tolist())
)
error_detail = "grad_norm across the workers:\n{}\n".format(
pretty_detail
)
# use FloatingPointError to trigger NanDetector
raise FloatingPointError(
"Fatal error: gradients are inconsistent between workers. "
"Try --ddp-backend=legacy_ddp. "
"Or are you mixing up different generation of GPUs in training?"
+ "\n"
+ "-" * 80
+ "\n{}\n".format(error_detail)
+ "-" * 80
)
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
if grad_norm is not None and (
not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm)
):
metrics.log_speed("ups", 1.0, priority=100, round=2)
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
if self.args.clip_norm > 0:
metrics.log_scalar(
"clip",
torch.where(
grad_norm > self.args.clip_norm,
grad_norm.new_tensor(100),
grad_norm.new_tensor(0),
),
priority=500,
round=1,
)
with metrics.aggregate() as agg:
if logging_outputs is not None:
self.task.reduce_metrics(logging_outputs, self.get_loss())
del logging_outputs
# extra warning for losses that don't properly log a loss value
if "loss" not in agg:
if "loss" not in self._warn_once:
self._warn_once.add("loss")
logger.warning(
"Loss.reduce_metrics did not log a 'loss' value, "
"which may break some functionality"
)
metrics.log_scalar("loss", -1)
logging_output = agg.get_smoothed_values()
logging_output["sample_size"] = sample_size
for key_to_delete in ["ppl", "wps", "wpb", "bsz"]:
if key_to_delete in logging_output:
del logging_output[key_to_delete]
return logging_output
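# --- Illustrative usage sketch (not used by the library) --------------------
# Rough shape of a single training update with this class; the real training
# loop lives in the unicore entry point and additionally handles validation,
# checkpointing and distributed setup. *samples* is assumed to be a list of
# mini-batches, one per gradient accumulation step (see --update-freq).
def _example_trainer_step(args, task, model, loss, samples):
    trainer = Trainer(args, task, model, loss)
    epoch_itr = trainer.get_train_iterator(epoch=1)
    trainer.init_total_train_steps(epoch_itr)
    logging_output = trainer.train_step(samples)
    trainer.lr_step(epoch=1)
    return logging_output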
def _catalog_shared_params(module, memo=None, prefix=""):
if memo is None:
first_call = True
memo = {}
else:
first_call = False
for name, param in module._parameters.items():
if param is None:
continue
param_prefix = prefix + ("." if prefix else "") + name
if param not in memo:
memo[param] = []
memo[param].append(param_prefix)
for name, m in module._modules.items():
if m is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
_catalog_shared_params(m, memo, submodule_prefix)
if first_call:
return [x for x in memo.values() if len(x) > 1]
def _get_module_by_path(module, path):
path = path.split(".")
for name in path:
module = getattr(module, name)
return module
def _set_module_by_path(module, path, value):
path = path.split(".")
for name in path[:-1]:
module = getattr(module, name)
setattr(module, path[-1], value)
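# --- Illustrative usage sketch (not used by the library) --------------------
# The helpers above are used in Trainer.__init__ to preserve weight tying
# across device/dtype transfers; a minimal, hypothetical example:
def _example_shared_parameters():
    model = torch.nn.Module()
    model.encoder = torch.nn.Linear(8, 8)
    model.decoder = torch.nn.Linear(8, 8)
    # tie the decoder weight to the encoder weight
    _set_module_by_path(
        model, "decoder.weight", _get_module_by_path(model, "encoder.weight")
    )
    # returns [["encoder.weight", "decoder.weight"]]
    return _catalog_shared_params(model)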