Commit 4d6ed501 authored by Deyu Fu

Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2

parents 690b1f71 9f64bf27
import torch
import xentropy_cuda
class SoftmaxCrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
losses, max_log_sum_exp = xentropy_cuda.forward(
logits, labels, smoothing, half_to_float)
losses.masked_fill_(labels==padding_idx, 0)
ctx.save_for_backward(logits, max_log_sum_exp, labels,
torch.FloatTensor([smoothing]),
torch.LongTensor([padding_idx]))
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==padding_idx.item(), 0)
grad_logits = xentropy_cuda.backward(
grad_loss.contiguous(), logits, max_log_sum_exp,
labels, smoothing.item())
return grad_logits, None, None, None, None
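For orientation, a minimal usage sketch of the fused loss above. It is only a sketch, assuming the xentropy_cuda extension has been built and that logits and labels live on the GPU; the shapes and smoothing value are illustrative, not taken from this commit.

import torch

# Hypothetical shapes: 8 token positions over a 32k-entry vocabulary.
logits = torch.randn(8, 32000, device="cuda", dtype=torch.half, requires_grad=True)
labels = torch.randint(1, 32000, (8,), device="cuda")

# Arguments mirror forward() above: smoothing, padding_idx, half_to_float.
losses = SoftmaxCrossEntropyLoss.apply(logits, labels, 0.1, 0, True)
losses.sum().backward()   # backward() returns grad_logits plus None placeholders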
from .fused_sgd import FusedSGD
from .novograd import FusedNovoGrad
from .fused_adam_v1 import FusedAdam_v1
from .adam import FusedAdam
#from .sgd import FusedSGD
from .fp16_optimizer import FP16_Optimizer
...@@ -2,8 +2,9 @@ import types
import torch
import importlib
from ..multi_tensor_apply import multi_tensor_applier

class FusedAdam_v1(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
...@@ -25,6 +26,8 @@ class FusedAdam_v1(torch.optim.Optimizer):
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
...@@ -35,10 +38,21 @@ class FusedAdam_v1(torch.optim.Optimizer):
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._use_multi_tensor = False
if use_mt:
if not multi_tensor_applier.available:
print("Warning: multi_tensor_applier is unavailable")
else:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0])
self._amp_scale_adjustment = amp_scale_adjustment
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
...@@ -66,6 +80,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
if closure is not None:
loss = closure()
if hasattr(self, "_amp_stash"):
grads = self._amp_stash.grads
output_params = self._amp_stash.output_params
scale = self._amp_stash.scale*self._amp_scale_adjustment
grad_norms = self._amp_stash.grad_norms
if grads is None:
grads_group = [None]*len(self.param_groups)
# backward compatibility
...@@ -105,6 +125,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
bias_correction = 1 if group['bias_correction'] else 0
if self._use_multi_tensor:
if output_params:
tensorlists = [[],[],[],[],[]]
else:
tensorlists = [[],[],[],[]]
for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None:
...@@ -130,18 +156,43 @@ class FusedAdam_v1(torch.optim.Optimizer):
state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
if self._use_multi_tensor:
pl = [p.data, exp_avg, exp_avg_sq, grad]
if output_param is not None:
pl.append(out_p)

for tl, t in zip(tensorlists, pl):
tl.append(t)
else:
fused_adam_cuda.adam(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
if self._use_multi_tensor:
multi_tensor_applier(
fused_adam_cuda.adam_mt,
self._overflow_buf,
tensorlists,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss
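A hedged construction sketch for FusedAdam_v1 with the new multi-tensor path. It assumes apex was installed with --cuda_ext (so fused_adam_cuda imports) and uses a toy model; use_mt and amp_scale_adjustment are the options added in this change.

import torch
from apex.optimizers import FusedAdam_v1

model = torch.nn.Linear(1024, 1024).cuda()

# use_mt=True batches the per-parameter updates through multi_tensor_applier
# (one fused adam_mt launch per tensor list) instead of one adam call per tensor.
optimizer = FusedAdam_v1(model.parameters(), lr=1e-3, betas=(0.9, 0.999),
                         eps=1e-8, weight_decay=0.01, use_mt=True)

loss = model(torch.randn(16, 1024, device="cuda")).sum()
loss.backward()
optimizer.step()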
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
explicit_master_params = (hasattr(self, "_amp_stash") and
hasattr(self._amp_stash, "fp32_from_fp16_groups"))
for gid, group in enumerate(self.param_groups):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs = [True, True]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
stash = self._amp_stash
fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if self.materialize_master_grads:
fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
else:
fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for i, p in enumerate(
stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp16_model_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
group['lr'],
nesterov,
first_run,
self.wd_after_momentum,
1.0/self.most_recent_scale)
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
return loss
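The FusedSGD docstring above notes that this implementation folds the learning rate into the parameter update rather than into the velocity. A plain-PyTorch sketch of the two conventions, for illustration only (it is not the fused kernel):

import torch

def momentum_step(p, g, v, lr=0.1, rho=0.9):
    # Convention used by this class: v = rho * v + g ; p = p - lr * v
    v.mul_(rho).add_(g)
    p.add_(v, alpha=-lr)

def momentum_step_sutskever(p, g, v, lr=0.1, rho=0.9):
    # Sutskever et al. convention: v = rho * v + lr * g ; p = p - v
    v.mul_(rho).add_(g, alpha=lr)
    p.sub_(v)

p = torch.zeros(4); g = torch.ones(4); v = torch.zeros(4)
momentum_step(p, g, v)   # p is now -0.1 everywhere, v is 1.0 everywhere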
...@@ -44,7 +44,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
buf.copy_(synced)
...@@ -54,7 +54,7 @@ def split_half_float_double(tensors):
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
return buckets
def split_by_type(tensors):
...@@ -69,12 +69,12 @@ def split_by_type(tensors):
# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
buckets = split_by_type(tensors)
for tp in buckets:
bucket = buckets[tp]
apply_flat_dist_call(bucket, call, extra_args)
def extract_tensors(maybe_tensor, tensor_list):
if torch.is_tensor(maybe_tensor):
tensor_list.append(maybe_tensor)
...@@ -85,7 +85,7 @@ def extract_tensors(maybe_tensor, tensor_list):
except TypeError:
return
class Reducer(object):
"""
:class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters
...@@ -93,13 +93,13 @@ class Reducer(object):
Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
parameters during ``backward()``.
Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
This enables, for example, delaying the allreduce to be carried out every
several iterations instead of every single iteration.
Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
over the number of participating processes.
:class:`Reducer` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
...@@ -107,7 +107,7 @@ class Reducer(object):
Args:
module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
"""
def __init__(self, module_or_grads_list):
if isinstance(module_or_grads_list, Module):
self.module = module_or_grads_list
...@@ -117,26 +117,26 @@ class Reducer(object):
self.module = None
self.grads = []
extract_tensors(module_or_grads_list, self.grads)
def reduce(self):
if self.module:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
else:
flat_dist_call(self.grads, dist.all_reduce)
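A short usage sketch for Reducer as described in the docstring above, assuming the process group was already initialized (for example by torch.distributed.launch) and one process per GPU; the model, batch shapes, and the every-fourth-iteration schedule are placeholders.

import torch
from apex.parallel import Reducer

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
reducer = Reducer(model)   # broadcasts parameters from rank 0 at construction

for step in range(8):
    loss = model(torch.randn(4, 10, device="cuda")).sum()
    loss.backward()
    if step % 4 == 3:            # delay the allreduce to every fourth iteration
        reducer.reduce()         # allreduces and averages the accumulated grads
        optimizer.step()
        optimizer.zero_grad()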
class DistributedDataParallel(Module):
"""
:class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
allreduced and averaged over processes during ``backward()``.
:class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
overlapping communication with computation during ``backward()`` and bucketing smaller gradient
transfers to reduce the total number of transfers required.
:class:`DistributedDataParallel` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
...@@ -159,19 +159,23 @@ class DistributedDataParallel(Module):
"""
def __init__(self,
module,
message_size=10000000,
delay_allreduce=False,
shared_param=None,
allreduce_trigger_params=None,
retain_allreduce_buffers=False,
allreduce_always_fp32=False,
num_allreduce_streams=1,
allreduce_communicators=None,
gradient_average=True,
gradient_predivide_factor=1.0,
gradient_average_split_factor=None,
prof=False):
super(DistributedDataParallel, self).__init__()
# Backward/forward compatibility around
# https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
# https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
if hasattr(dist, "get_backend"):
...@@ -181,13 +185,26 @@ class DistributedDataParallel(Module):
else:
self.backend_enum_holder = dist.Backend
else:
self._backend = dist._backend
self.backend_enum_holder = dist.dist_backend
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.prof = prof
self.allreduce_different_streams = (num_allreduce_streams > 1)
self.num_allreduce_streams = num_allreduce_streams
self.allreduce_communicators = allreduce_communicators
if self.allreduce_communicators:
assert len(allreduce_communicators[0]) == num_allreduce_streams
assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
assert self.allreduce_different_streams
if self.allreduce_different_streams and delay_allreduce:
raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
self.world_size = float(dist.get_world_size())
...@@ -199,27 +216,29 @@ class DistributedDataParallel(Module):
self.custom_allreduce_triggers = False
if allreduce_trigger_params is not None:
if delay_allreduce:
raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
self.custom_allreduce_triggers = True
self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])
self.delay_allreduce = delay_allreduce
self.message_size = message_size
self.main_stream = torch.cuda.current_stream()
self.bucket_streams = []
self.bucket_events = []
self.module = module
self._disable_allreduce = False
if self._backend == self.backend_enum_holder.NCCL:
for param in self.module.parameters():
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.active_params = []
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
"torch.cuda.FloatTensor" : 1,
"torch.cuda.DoubleTensor" : 2}
...@@ -236,15 +255,21 @@ class DistributedDataParallel(Module):
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
if self.allreduce_different_streams and delay_allreduce:
raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if self.delay_allreduce:
self.needs_refresh = True
self.bucket_streams = []
self.bucket_events = []
def __getstate__(self):
attrs = copy.copy(self.__dict__)
if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.bucket_streams']
del attrs['self.bucket_events']
return attrs
def enable_allreduce(self):
...@@ -252,9 +277,9 @@ class DistributedDataParallel(Module):
def disable_allreduce(self):
self._disable_allreduce = True
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
def sync_bucket_structure(self):
# Append leftover buckets
for tmp_bucket in self.tmp_buckets:
...@@ -264,8 +289,8 @@ class DistributedDataParallel(Module):
self.num_buckets = len(self.active_i_buckets)
self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]
info_tensor = torch.cuda.IntTensor([self.num_buckets] +
self.bucket_sizes +
list(chain(*self.active_i_buckets)))
dist.broadcast(info_tensor, 0)
...@@ -273,27 +298,27 @@ class DistributedDataParallel(Module):
info = [int(entry) for entry in info_tensor]
self.num_buckets = info[0]
self.bucket_sizes = info[1:self.num_buckets + 1]
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
# Technically, active_i_buckets' work is done. But the information is still useful to
# keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
flattened_buckets = info[self.num_buckets + 1:]
flat_i = 0
for bucket_idx in range(self.num_buckets):
for bucket_loc in range(self.bucket_sizes[bucket_idx]):
param_i = flattened_buckets[flat_i]
self.active_i_buckets[bucket_idx][bucket_loc] = param_i
self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
flat_i += 1
def create_hooks(self):
# Fallback hook that's only called at the end of backward.
# Used if you deliberately want to delay allreduces to the end, or to refresh the
# bucket structure that will be used to overlap communication with computation in later
# iterations.
def allreduce_params():
...@@ -308,9 +333,10 @@ class DistributedDataParallel(Module):
def overlapping_backward_epilogue():
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
# Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets:
raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
...@@ -320,7 +346,7 @@ class DistributedDataParallel(Module):
for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
if actual != expected:
raise RuntimeError("Some param buckets were not allreduced.")
self.grad_accs = []
for param in self.module.parameters():
...@@ -330,6 +356,9 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.prof:
torch.cuda.nvtx.range_push("allreduce_hook")
if not self._disable_allreduce:
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
...@@ -341,8 +370,8 @@ class DistributedDataParallel(Module):
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
...@@ -359,82 +388,133 @@ class DistributedDataParallel(Module):
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.comm_ready_buckets(param)
if self.prof:
torch.cuda.nvtx.range_pop()
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
wrapper(param)
def _stream_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_streams[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_streams[0]
def _event_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_events[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_events[0]
def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
tensor = flatten(bucket)
if force_default_stream:
bucket_stream = self.main_stream
else:
bucket_stream = self._stream_this_bucket(bucket_idx)
bucket_event = self._event_this_bucket(bucket_idx)
torch.cuda.current_stream().record_event(bucket_event)
bucket_stream.wait_event(bucket_event)
with torch.cuda.stream(bucket_stream):
# self.main_stream.wait_stream(torch.cuda.current_stream())
# torch.cuda.synchronize()
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
if self.allreduce_different_streams and not force_default_stream:
dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])
else:
dist.all_reduce(tensor_to_allreduce)
if self.gradient_average:
if self.gradient_predivide_factor != self.world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
if not self.retain_allreduce_buffers:
if multi_tensor_applier.available:
multi_tensor_applier(
self.multi_tensor_scale,
self._overflow_buf,
[unflatten(tensor, bucket), bucket],
1.0)
else:
for buf, synced in zip(bucket, unflatten(tensor, bucket)):
buf.copy_(synced)
# I think we actually do need this here. After allreduce_bucket returns, tensor will
# eventually go out of scope and die, at which point it could otherwise be freed for
# further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
tensor.record_stream(bucket_stream)
return tensor

def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
if self.retain_allreduce_buffers:
if self.allreduce_buffers[bucket_idx] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled "
"allreduce buffer. This is almost certainly an error.")
self.allreduce_buffers[bucket_idx] = allreduced
for view, grad in zip(unflatten(allreduced, bucket), bucket):
grad.data = view
# for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
# buf.copy_(synced)
def allreduce_fallback(self):
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
if self.retain_allreduce_buffers:
grads = [param.grad for param in self.module.parameters() if param.grad is not None]
else:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
split_buckets = split_half_float_double(grads)
# If retain_allreduce_buffers is True and delay_allreduce is False,
# this will only be done during the first backward pass, ignored by the
# training script, and overwritten in the next forward pass. So it's harmless.
if self.retain_allreduce_buffers:
self.allreduce_buffers = [None for _ in range(len(split_buckets))]
for i, bucket in enumerate(split_buckets):
allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream())
if self.prof:
torch.cuda.nvtx.range_push("comm_ready_buckets")
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
...@@ -442,39 +522,46 @@ class DistributedDataParallel(Module):
raise RuntimeError("The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error.")
if self.retain_allreduce_buffers:
self.buckets[bucket_idx][bucket_loc] = param.grad
else:
self.buckets[bucket_idx][bucket_loc] = param.grad.data
self.buckets_ready_size[bucket_idx] += 1
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
# Reversing upstream's logic here, because we constructed our buckets based on
# the order things were received during backward.
if len(self.ready_buckets_not_reduced) > 0:
sorted_todo = sorted(self.ready_buckets_not_reduced)
for i in sorted_todo:
# Nothing can be reduced now
if i > self.next_bucket:
break
elif i == self.next_bucket:
self.allreduce_maybe_retain(self.buckets[i], i)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
raise ValueError("i should always be >= next_bucket")
else:
self.ready_buckets_not_reduced.add(bucket_idx)
if self.prof:
torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self._disable_allreduce:
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
...@@ -483,7 +570,7 @@ class DistributedDataParallel(Module):
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
...@@ -494,19 +581,59 @@ class DistributedDataParallel(Module):
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else:
# self.buckets = [[None for _ in range(self.bucket_sizes[i])]
# for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
b, len(buckets[b]), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_communicators:
self.bucket_pgs = self.allreduce_communicators[0]
self.bucket_streams = self.allreduce_communicators[1]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
self.active_params = param_list
self.callback_queued = False
if self.prof:
torch.cuda.nvtx.range_pop()
return result
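A hedged end-to-end sketch of wrapping a model with the DDP changes in this diff; it assumes one process per GPU launched via torch.distributed.launch with NCCL available, and the num_allreduce_streams value is only illustrative.

import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as ApexDDP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()

# num_allreduce_streams > 1 spreads bucket allreduces over several streams
# (and, by default, one new process group per stream); it requires
# delay_allreduce=False, as enforced in __init__ above.
model = ApexDDP(model, message_size=10000000, num_allreduce_streams=2, prof=False)

out = model(torch.randn(8, 1024, device="cuda"))
out.sum().backward()   # gradients are bucketed and allreduced during backward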
...@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda()
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
self.fuse_relu = fuse_relu
def _specify_process_group(self, process_group):
self.process_group = process_group
...@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input, z = None):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
...@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
...@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
...@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
ctx.fuse_relu = fuse_relu
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
...@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
fuse_relu = ctx.fuse_relu
grad_input = grad_z = grad_weight = grad_bias = None
if fuse_relu:
grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output?
if channel_last:
...@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
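A small usage sketch for the fuse_relu / z additions above, assuming the syncbn extension is built, a NCCL process group is initialized, and channel-last activations; the exact role of z (a second input consumed by the fused kernel, presumably a residual branch) is an assumption read off the new forward signature, not stated in the diff.

import torch
import torch.distributed as dist
from apex.parallel import SyncBatchNorm

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# channel_last and fuse_relu are the options touched by this commit; z is the
# optional second input threaded through SyncBatchnormFunction.apply above.
bn = SyncBatchNorm(64, channel_last=True, fuse_relu=True).cuda()

x = torch.randn(8, 14, 14, 64, device="cuda", requires_grad=True)   # NHWC layout
z = torch.randn_like(x)
y = bn(x, z)
y.sum().backward()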
...@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale);
void multi_tensor_axpby_cuda(
int chunk_size,
at::Tensor noop_flag,
...@@ -72,6 +85,8 @@ void multi_tensor_novograd_cuda(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
"out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
......
...@@ -3,6 +3,9 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
...@@ -9,6 +9,10 @@ ...@@ -9,6 +9,10 @@
// #include "ATen/Type.h" // #include "ATen/Type.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h> #include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
#include "type_shim.h" #include "type_shim.h"
...@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel( ...@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
} }
} }
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
};
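For reference, the per-element arithmetic the functor performs, restated as a scalar Python sketch (mode 0, ADAM_MODE_0, adds eps inside the square root; mode 1 adds it afterwards, matching the branch above):

def adam_update(p, m, v, g, b1, b2, eps, grad_scale, step_size, mode, decay):
    # mirrors one element of AdamFunctor: scale the gradient, update the moments, take the step
    g = g / grad_scale
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    denom = (v + eps) ** 0.5 if mode == 0 else v ** 0.5 + eps
    update = m / denom + decay * p
    return p - step_size * update, m, v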
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
...@@ -96,7 +187,7 @@ void fused_adam_cuda(
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(),
...@@ -112,7 +203,7 @@ void fused_adam_cuda(
tsize,
(adamMode_t) mode,
decay);
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
...@@ -135,3 +226,110 @@ void fused_adam_cuda(
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
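A hedged usage sketch of the new multi-tensor entry point (exposed as adam_mt in the frontend above). The tensor-list order is p, m, v, g, with an optional fifth list that receives an fp16 parameter copy; the chunk size and hyperparameters below are illustrative assumptions.

import torch
import fused_adam_cuda

noop_flag = torch.cuda.IntTensor([0])
p = [torch.zeros(1024, device='cuda')]
m = [torch.zeros(1024, device='cuda')]
v = [torch.zeros(1024, device='cuda')]
g = [torch.randn(1024, device='cuda')]

fused_adam_cuda.adam_mt(
    2048 * 32,        # chunk_size (assumed value)
    noop_flag,
    [p, m, v, g],     # append a fifth list of fp16 tensors to have p_copy written
    1e-3,             # lr
    0.9, 0.999,       # beta1, beta2
    1e-8,             # eps
    1.0,              # grad_scale
    1,                # step
    0,                # mode (0: eps inside the sqrt, per the kernel)
    1,                # bias_correction
    0.0)              # decay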
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* N: number of tensors
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* tl[3] : fp16 weights (if appropriate)
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<N>& tl,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
grad_in += chunk_idx*chunk_size;
T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
weight_in += chunk_idx*chunk_size;
T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
mom_in += chunk_idx*chunk_size;
at::Half *model_weights_out = nullptr;
if(N == 4)
{
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
// apply weight decay before momentum if necessary
if(wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if(momentum != 0.f)
{
if(!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if(nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if(wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if(N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
if(momentum != 0.f)
mom_in[i] = incoming_moms[ii];
}
}
}
}
};
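The per-element update each thread applies in SGDFunctor, restated as a scalar Python sketch (weight decay applied before or after the momentum update, momentum buffer initialized from the gradient on the first run, optional Nesterov):

def sgd_update(g, w, mom, wd, momentum, dampening, lr,
               nesterov, first_run, wd_after_momentum, scale):
    g = g * scale
    if wd != 0.0 and not wd_after_momentum:
        g += wd * w
    if momentum != 0.0:
        mom = g if first_run else mom * momentum + (1.0 - dampening) * g
        g = (g + momentum * mom) if nesterov else mom
    if wd != 0.0 and wd_after_momentum:
        g += wd * w
    w = w - lr * g
    return w, mom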
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if(num_tensors == 4)
for(int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
// We have 4 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, at::Half, at::Half>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::Half, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
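A hedged sketch of the mixed-precision path (case 3 above) from Python: fp16 gradients, fp32 master weights and momentum, and a fourth list that receives the refreshed fp16 model weights. Chunk size and hyperparameters are illustrative assumptions.

import torch
import amp_C

noop_flag = torch.cuda.IntTensor([0])
g16 = [torch.randn(1024, device='cuda', dtype=torch.float16)]
w32 = [torch.randn(1024, device='cuda', dtype=torch.float32)]
m32 = [torch.zeros(1024, device='cuda', dtype=torch.float32)]
w16 = [torch.empty(1024, device='cuda', dtype=torch.float16)]   # fp16 copy written by the kernel

amp_C.multi_tensor_sgd(2048 * 32, noop_flag, [g16, w32, m32, w16],
                       1e-4, 0.9, 0.0, 0.1, False, True, False, 1.0)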
...@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
...@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
...@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
...@@ -591,6 +591,58 @@ template <
int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int reduction_size,
const int stride,
const bool fuse_relu) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
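Stated compactly, the extended forward kernel applies the following per-element math (a Python sketch of the same expression): normalize, scale and shift, optionally add the residual z, then clamp at zero when fuse_relu is set.

def bn_forward_elem(x, z, mean_c, inv_std_c, w_c, s_c, fuse_relu):
    # normalize with the per-channel statistics, apply affine, optional residual add, optional ReLU
    tmp = w_c * (x - mean_c) * inv_std_c + s_c
    if z is not None:
        tmp += z
    return 0.0 if (fuse_relu and tmp <= 0.0) else tmp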
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
...@@ -619,9 +671,11 @@ __global__ void batchnorm_forward_c_last_kernel(
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
}
m_offset += inner_loop_stride;
address_base += address_increment;
...@@ -1147,10 +1201,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift,
const bool fuse_relu) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
...@@ -1170,13 +1226,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride,
fuse_relu);
);
} else {
if (weight.has_value()) {
...@@ -1189,13 +1247,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride,
fuse_relu);
);
}
return out;
...@@ -1351,3 +1411,66 @@ at::Tensor batchnorm_backward_c_last_CUDA(
return grad_input;
}
at::Tensor relu_backward_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor out = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
} else {
if (weight.has_value()) {
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
}
return out;
}
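Per element, relu_backward_c_last_CUDA recomputes the pre-ReLU activation from the saved statistics (plus the optional residual z) and passes the incoming gradient through only where that activation was positive. A Python sketch of the same gating:

def relu_backward_elem(grad_out, x, z, mean_c, inv_std_c, w_c, s_c):
    # recompute the fused BN(+z) output that fed the ReLU, then gate the gradient on its sign
    pre_relu = w_c * (x - mean_c) * inv_std_c + s_c
    if z is not None:
        pre_relu += z
    return grad_out if pre_relu > 0.0 else 0.0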
...@@ -25,54 +25,6 @@ try:
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
cudnn.benchmark = True
def fast_collate(batch):
imgs = [img[0] for img in batch]
...@@ -90,24 +42,75 @@ def fast_collate(batch):
return tensor, targets
def parse():
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()
return args
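# A small sketch (hypothetical helper, not part of this script) of the LR scaling
# promised in the --lr help text above:
def scaled_initial_lr(base_lr, batch_size, world_size):
    # linear scaling rule: LR grows with the global batch (per-process batch * world size) relative to 256
    return base_lr * float(batch_size * world_size) / 256.0
# e.g. scaled_initial_lr(0.1, 256, 8) == 0.8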
def main():
global best_prec1, args
args = parse()
print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
cudnn.benchmark = True
best_prec1 = 0
if args.deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10)
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
......
...@@ -72,6 +72,7 @@ if "--cuda_ext" in sys.argv:
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
...@@ -104,6 +105,58 @@ if "--cuda_ext" in sys.argv:
'-O3',
'--use_fast_math'] + version_ge_1_1}))
if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='bnp',
sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
extra_compile_args={'cxx': [] + version_ge_1_1,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode',
'arch=compute_70,code=sm_70'] + version_ge_1_1}))
if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--xentropy")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=['csrc'],
extra_compile_args={'cxx': ['-O3'] + version_ge_1_1,
'nvcc':['-O3'] + version_ge_1_1}))
setup(
name='apex',
version='0.1',
......
...@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase):
fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
......
import unittest
import functools as ft
import itertools as it
from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
try:
import amp_C
disabled = False
from apex.optimizers import FusedSGD as FusedSGD
except ImportError as err:
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
disabled = True
class MyModel(torch.nn.Module):
def __init__(self, unique):
super(MyModel, self).__init__()
self.weight0 = Parameter(unique +
torch.arange(2, device='cuda', dtype=torch.float32))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
@staticmethod
def ops(input, weight0, weight1):
return ((input*(weight0.float()))*(weight1.float())).sum()
def forward(self, input):
return self.ops(input, self.weight0, self.weight1)
# Abandon all hope, ye who enter here.
# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.
class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def setUp(self):
self.x = torch.ones((2), device='cuda', dtype=torch.float32)
common_init(self)
def tearDown(self):
pass
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], optimizer = amp.initialize(
[model0, model1],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()] +
[param.grad.data.clone() for param in model2.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 2)
elif which_backward == 1:
which_models = (1, 2)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], optimizer = amp.initialize(
[model0, model1, model2],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(str(which_model) + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 1:
inj_model = model1
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(str(which_model) + " invalid for loss 1")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Don't do it like this: reference_grads = [[]]*5
# because then it creates a list of 5 references to the same "[]" and appending
# to any of them effectively makes you append to all of them, which multiplies
# the resulting size of reference_grads by 5x and needless to say makes the test fail.
reference_grads = [[], [], [], [], []]
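# Quick illustration of that aliasing pitfall:
#   aliased = [[]] * 3; aliased[0].append(1)                        # -> [[1], [1], [1]] (one shared list)
#   independent = [[] for _ in range(3)]; independent[0].append(1)  # -> [[1], [], []]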
final_params = [None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
def what_got_skipped(which_iter, which_backward):
if which_iter == 0 and which_backward == 0:
return 1
if which_iter == 0 and which_backward == 1:
return 2
if which_iter == 1 and which_backward == 0:
return 3
if which_iter == 1 and which_backward == 1:
return 4
return 0
for which_iter in (0,1):
for which_backward in (0,1):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter, which_backward)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
optimizer1.step()
else:
optimizer0.step()
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], [optimizer0, optimizer1] = amp.initialize(
[model0, model1],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Again, can't do this: reference_grads = [[]]*9
reference_grads = [[], [], [], [], [], [], [], [], []]
final_params = [None, None, None, None, None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
def what_got_skipped(which_iter, which_backward, which_model):
if which_iter == 0:
if which_backward == 0:
if which_model == 0:
return 1
if which_model == 1:
return 2
if which_backward == 1:
if which_model == 2:
return 3
if which_model == 1:
return 4
if which_iter == 1:
if which_backward == 0:
if which_model == 0:
return 5
if which_model == 1:
return 6
if which_backward == 1:
if which_model == 2:
return 7
if which_model == 1:
return 8
return 0
for which_iter in (0,1):
for which_backward in (0,1):
if which_backward == 0:
which_models = (0,1)
if which_backward == 1:
which_models = (2,1)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter,
which_backward, which_model)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
# if which_model == 0:
optimizer1.step()
# if which_model == 1:
# optimizer1.step()
if which_backward == 1:
# if which_model == 2:
# optimizer0.step()
# if which_model == 1:
continue
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 1)
elif which_backward == 1:
which_models = (2, 1)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
[model0, model1, model2],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(str(which_model) + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 2:
inj_model = model2
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(str(which_model) + " invalid for loss 1")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf,
which_backward, which_model)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
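                                            # After the final step, model params, master params, and the recorded
                                            # reference final params must all agree.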
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
if __name__ == '__main__':
unittest.main()
...@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
    def tearDown(self):
        pass

-    def gen_param_optim(self, tensors, adam_option):
+    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))
-        ref_optim = torch.optim.Adam(ref_param, **adam_option)
-        tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **adam_option)
+        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
+        if tst_adam_option:
+            tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **tst_adam_option)
+        else:
+            tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **ref_adam_option)
        return (ref_param, tst_param, ref_optim, tst_optim)
...@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
-            max_abs_diff_p = (p_ref - p_tst).abs().max().item()
-            max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
+            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
+            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()
            if max_abs_diff_p > max_abs_diff:  max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff:  max_rel_diff = max_rel_diff_p
...@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_multi_tensor(self):
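        # New in this commit: exercise the multi-tensor path (use_mt=True) of FusedAdam_v1
        # against torch.optim.Adam, feeding fp16 gradients and checking that both the fp32
        # master params and the fp16 output params track the reference.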
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
ref_adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
'weight_decay':0, 'amsgrad':False}
tst_adam_option = dict(ref_adam_option, **{'use_mt':True})
tensors = []
fp16_params = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)
for i in range(self.iters):
half_grads = self.gen_mixed_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step(grads=half_grads, output_params=fp16_params)
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
fp16_params)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
...
#!/bin/bash
-DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
+# DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
# DATADIR="/opt/home/apex/examples/imagenet/"
cp ../common/* .
-bash run_test.sh single_gpu $1 $DATADIR yes
+bash run_test.sh single_gpu $1
...@@ -35,8 +35,9 @@ class Model(Module):
model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
-model = DDP(model, delay_allreduce=True)
+# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
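# New in this commit: trigger allreduces on model.b, now with num_allreduce_streams=3.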
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)
x = torch.cuda.FloatTensor(4096*4096)
...