Commit 6af5980e authored by Michael Carilli

Merging in FusedAdam treatment

parents 16a3bdf3 7aad54f7
# Introduction

This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.

## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)

@@ -29,7 +29,7 @@ different flags to `amp.initialize`.

## 2. Distributed Training

`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
...
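Aside (not part of the diff): a minimal sketch of how the two utilities described above are typically combined. The model, optimizer, and hyperparameters are placeholders, and it assumes `torch.distributed` has already been initialized by a launcher such as `torch.distributed.launch`.

```python
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

# Placeholder model/optimizer; amp.initialize returns patched versions of both.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
model = DistributedDataParallel(model)  # gradients are allreduced during backward()
```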
@@ -114,29 +114,13 @@ def check_optimizers(optimizers):
         raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
                            "The optimizer(s) passed to amp.initialize() must be bare \n"
                            "instances of either ordinary Pytorch optimizers, or Apex fused \n"
-                           "optimizers (currently just FusedAdam, but FusedSGD will be added \n"
-                           "soon). You should not manually wrap your optimizer in either \n"
+                           "optimizers (FusedAdam or FusedSGD). \n"
+                           "You should not manually wrap your optimizer in either \n"
                            "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                            "amp.initialize will take care of that for you (if necessary) based \n"
                            "on the specified opt_level (and optional overridden properties).")


-def wrap_fused_adam(optimizer, properties):
-    msg = 'Currently, the usage of FusedAdam is restricted to '\
-          'amp.initialize(..., opt_level="O2", keep_batchnorm_fp32=False, '\
-          'loss_scale=float or "dynamic"). We are working on enabling more general usage.'
-
-    assert properties.master_weights is True, msg
-    assert properties.cast_model_type is torch.float16, msg
-    assert (properties.keep_batchnorm_fp32 is False or
-            properties.keep_batchnorm_fp32 is None), msg
-
-    if properties.loss_scale == "dynamic":
-        return FP16_Optimizer_for_fused(optimizer, dynamic_loss_scale=True)
-    else:
-        return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
-
-
 def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init

@@ -163,7 +147,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
     if not _amp_state.allow_incoming_model_not_fp32:
         check_params_fp32(models)

     check_optimizers(optimizers)

     # In the future, when FP16_Optimizer can be deprecated and master weights can

@@ -196,7 +180,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
             model.forward = patch_forward(model.forward)

         # State dict trick to recast any preexisting per-param state tensors
         for optimizer in optimizers:
             optimizer.load_state_dict(optimizer.state_dict())

     elif cast_model_outputs is not None:

@@ -212,11 +196,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
             model.forward = patch_forward(model.forward)

     for i, optimizer in enumerate(optimizers):
-        # Still need to special case this for the first pass
-        if isinstance(optimizer, FusedAdam):
-            optimizers[i] = wrap_fused_adam(optimizer, properties)
-        else:
-            optimizers[i] = _process_optimizer(optimizer, properties)
+        optimizers[i] = _process_optimizer(optimizer, properties)

     _amp_state.loss_scalers = []
     for _ in range(num_losses):
...
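Aside: the error message above asks for bare optimizer instances. A hedged sketch of the intended call pattern (model and hyperparameters are placeholders); after this merge, `amp.initialize` applies the FusedAdam-specific handling itself rather than relying on `wrap_fused_adam`.

```python
import torch
from apex import amp
from apex.optimizers import FusedAdam

model = torch.nn.Linear(512, 512).cuda()
# Pass a bare FusedAdam; do not wrap it in FP16_Optimizer yourself.
optimizer = FusedAdam(model.parameters(), lr=1e-4)

model, optimizer = amp.initialize(model, optimizer,
                                  opt_level="O2", loss_scale="dynamic")
```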
@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
 from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import maybe_print
 import torch
+from ..optimizers import FusedAdam


 class AmpOptimizerState(object):

@@ -73,6 +74,40 @@ def lazy_init_with_master_weights(self):
     self.load_state_dict(self.state_dict())


+def post_backward_models_are_masters(scaler, params, stashed_grads):
+    # This is a lot of python overhead...
+    grads_needing_unscale = []
+    grads_needing_unscale_with_stash = []
+    stashed = []
+    for param, stashed_grad in zip(params, stashed_grads):
+        if param.grad is None and stashed_grad is not None:
+            param.grad = stashed_grad
+        elif param.grad is not None and stashed_grad is None:
+            grads_needing_unscale.append(param.grad)
+        elif param.grad is not None and stashed_grad is not None:
+            grads_needing_unscale_with_stash.append(param.grad)
+            stashed.append(stashed_grad)
+        else: # param.grad is None and stashed_grad is None
+            continue
+
+    if len(grads_needing_unscale) > 0:
+        scaler.unscale(
+            grads_needing_unscale,
+            grads_needing_unscale,
+            scaler.loss_scale(),
+            models_are_masters=True)
+
+    if len(grads_needing_unscale_with_stash) > 0:
+        scaler.unscale_with_stashed(
+            grads_needing_unscale_with_stash,
+            stashed,
+            grads_needing_unscale_with_stash)
+
+    # Clear the stash.
+    for i in range(len(stashed_grads)):
+        stashed_grads[i] = None
+
+
 def prepare_backward_with_master_weights(self):
     stash = self._amp_stash

@@ -106,7 +141,7 @@ def post_backward_with_master_weights(self, scaler):
         if fp16_param.grad is None and fp32_param.grad is not None:
             continue
         elif fp16_param.grad is not None and fp32_param.grad is None:
             fp32_param.grad = torch.empty_like(fp32_param)
             fp16_grads_needing_unscale.append(fp16_param.grad)
             new_fp32_grads.append(fp32_param.grad)
         elif fp16_param.grad is not None and fp32_param.grad is not None:

@@ -129,37 +164,10 @@ def post_backward_with_master_weights(self, scaler):
                                     preexisting_fp32_grads)

     # fp32 params can be treated as they would be in the "no_master_weights" case.
-    grads_needing_unscale = []
-    grads_needing_unscale_with_stash = []
-    stashed = []
-    for param, stashed_grad in zip(stash.all_fp32_from_fp32_params,
-                                   stash.all_fp32_from_fp32_grad_stash):
-        if param.grad is None and stashed_grad is not None:
-            param.grad = stashed_grad
-        elif param.grad is not None and stashed_grad is None:
-            grads_needing_unscale.append(param.grad)
-        elif param.grad is not None and stashed_grad is not None:
-            grads_needing_unscale_with_stash.append(param.grad)
-            stashed.append(stashed_grad)
-        else: # param.grad is None and stashed_grad is None:
-            continue
-
-    if len(grads_needing_unscale) > 0:
-        scaler.unscale(
-            grads_needing_unscale,
-            grads_needing_unscale,
-            scaler.loss_scale(),
-            models_are_masters=True)
-
-    if len(grads_needing_unscale_with_stash) > 0:
-        scaler.unscale_with_stashed(
-            grads_needing_unscale_with_stash,
-            stashed,
-            grads_needing_unscale_with_stash)
-
-    # Clear the stash.
-    for i in range(len(stash.all_fp32_from_fp32_grad_stash)):
-        stash.all_fp32_from_fp32_grad_stash[i] = None
+    post_backward_models_are_masters(
+        scaler,
+        stash.all_fp32_from_fp32_params,
+        stash.all_fp32_from_fp32_grad_stash)
 def lazy_init_no_master_weights(self):

@@ -176,7 +184,7 @@ def lazy_init_no_master_weights(self):
             raise TypeError("Optimizer's parameters must be either "
                             "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                             "Received {}".format(param.type()))

     stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
     stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]

@@ -206,37 +214,56 @@ def post_backward_no_master_weights(self, scaler):
                    (stash.all_fp32_params, stash.all_fp32_grad_stash))

     for params, stashed_grads in split_types:
-        # This is a lot of python overhead...
-        grads_needing_unscale = []
-        grads_needing_unscale_with_stash = []
-        stashed = []
-        for param, stashed_grad in zip(params, stashed_grads):
-            if param.grad is None and stashed_grad is not None:
-                param.grad = stashed_grad
-            elif param.grad is not None and stashed_grad is None:
-                grads_needing_unscale.append(param.grad)
-            elif param.grad is not None and stashed_grad is not None:
-                grads_needing_unscale_with_stash.append(param.grad)
-                stashed.append(stashed_grad)
-            else: # param.grad is None and stashed_grad is None
-                continue
-
-        if len(grads_needing_unscale) > 0:
-            scaler.unscale(
-                grads_needing_unscale,
-                grads_needing_unscale,
-                scaler.loss_scale(),
-                models_are_masters=True)
-
-        if len(grads_needing_unscale_with_stash) > 0:
-            scaler.unscale_with_stashed(
-                grads_needing_unscale_with_stash,
-                stashed,
-                grads_needing_unscale_with_stash)
-
-        # Clear the stash.
-        for i in range(len(stashed_grads)):
-            stashed_grads[i] = None
+        post_backward_models_are_masters(scaler, params, stashed_grads)
+
+
+def prepare_backward_with_master_weights_fused(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_with_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
+    stash.output_params = [[param for param in group] for group in stash.fp16_groups]
+
+    norm_groups = []
+    skip = False
+    for grad_group in stash.grads:
+        norm = multi_tensor_applier(
+            stash.multi_tensor_l2norm,
+            stash.dummy_overflow_buf,
+            [grad_group])
+        # Still syncing here for now.
+        norm = float(norm)
+        norm_groups.append(norm)
+        if norm == float('inf') or norm == -float('inf') or norm != norm:
+            skip = True
+
+    if skip:
+        scaler._overflow_buf.fill_(1.)
+        scaler._has_overflow = True
+
+    stash.grad_norms = norm_groups
+
+
+def prepare_backward_no_master_weights_fused(self):
+    stash = self._amp_stash
+
+    if not stash.lazy_init_called:
+        self._lazy_init_maybe_master_weights()
+        stash.lazy_init_called = True
+
+
+def post_backward_no_master_weights_fused(self, scaler):
+    stash = self._amp_stash
+    stash.scale = scaler.loss_scale()
+    stash.grads = None
+    stash.output_params = None
+    stash.grad_norms = None
 def _master_params_to_model_params(self):

@@ -274,6 +301,7 @@ def _process_optimizer(optimizer, properties):
     if multi_tensor_applier.available:
         import amp_C
         optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
+        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
         optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0]);

     if properties.master_weights:

@@ -286,7 +314,8 @@ def _process_optimizer(optimizer, properties):
         old_step = optimizer.step
         def new_step(self):
             retval = old_step()
-            self._master_params_to_model_params()
+            if not isinstance(self, FusedAdam):
+                self._master_params_to_model_params()
             # Clear the master grads that wouldn't be zeroed by model.zero_grad()
             for param in self._amp_stash.all_fp32_from_fp16_params:
                 param.grad = None

@@ -313,19 +342,29 @@ def _process_optimizer(optimizer, properties):
                     param.grad = None
         optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_with_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_with_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights_fused, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_with_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_with_master_weights, optimizer)
     else:
         optimizer._lazy_init_maybe_master_weights = types.MethodType(
             lazy_init_no_master_weights, optimizer)

-        optimizer._prepare_amp_backward = types.MethodType(
-            prepare_backward_no_master_weights, optimizer)
-
-        optimizer._post_amp_backward = types.MethodType(
-            post_backward_no_master_weights, optimizer)
+        if isinstance(optimizer, FusedAdam):
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights_fused, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights_fused, optimizer)
+        else:
+            optimizer._prepare_amp_backward = types.MethodType(
+                prepare_backward_no_master_weights, optimizer)
+            optimizer._post_amp_backward = types.MethodType(
+                post_backward_no_master_weights, optimizer)

     return optimizer
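Aside: `_process_optimizer` attaches the prepare/post backward hooks with `types.MethodType`, which binds a free function to a single instance. A generic illustration of that pattern (the `Optimizer` class and hook below are stand-ins, not Apex code):

```python
import types

class Optimizer(object):
    pass

def _prepare_amp_backward(self):
    # Stand-in behavior; the real hooks stash and unscale gradients.
    print("preparing backward for", type(self).__name__)

opt = Optimizer()
# Bind the function as a method on this one instance only.
opt._prepare_amp_backward = types.MethodType(_prepare_amp_backward, opt)
opt._prepare_amp_backward()
```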
@@ -6,8 +6,6 @@ from . import utils
 from .opt import OptimWrapper
 from .scaler import LossScaler
 from ._amp_state import _amp_state, master_params, maybe_print
-from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
-from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused

 # There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.

@@ -82,13 +80,8 @@ def scale_loss(loss,
     if isinstance(optimizers, torch.optim.Optimizer):
         optimizers = [optimizers]

-    # this is what happens when i have to support tools from different sources under the same API...
-    # TODO: Rewrite FusedAdam to use multi-tensor apply and the same loss scaler.
-    if isinstance(optimizers, FP16_Optimizer_for_fused):
-        loss_scale = optimizers.cur_scale
-    else:
-        loss_scaler = _amp_state.loss_scalers[loss_id]
-        loss_scale = loss_scaler.loss_scale()
+    loss_scaler = _amp_state.loss_scalers[loss_id]
+    loss_scale = loss_scaler.loss_scale()

     if ((not _amp_state.opt_properties.master_weights)
         and (not loss_scaler.dynamic)

@@ -113,8 +106,8 @@ def scale_loss(loss,
         for optimizer in optimizers:
             optimizer._amp_stash.params_have_scaled_gradients = True
     else:
-        # FusedAdam and FusedSGD will take care of unscaling as part of their step() methods.
-        if not isinstance(optimizers, FP16_Optimizer_for_fused):
+        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
+        # if not isinstance(optimizers, FP16_Optimizer_for_fused):
         loss_scaler.clear_overflow_state()
         for optimizer in optimizers:
             optimizer._post_amp_backward(loss_scaler)
...
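Aside: with the special-casing removed, every optimizer goes through the same `scale_loss` path. A hedged usage sketch (model and shapes are placeholders):

```python
import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

loss = model(torch.randn(4, 16, device="cuda")).sum()
# scale_loss scales the loss on entry and triggers _post_amp_backward
# (gradient unscaling / overflow checks) on exit.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()
```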
@@ -2,6 +2,8 @@ import types
 import torch
 import importlib

+from ..multi_tensor_apply import multi_tensor_applier
+
 class FusedAdam(torch.optim.Optimizer):

     """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via

@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer):
             adds eps to the bias-corrected second moment estimate before
             evaluating square root instead of adding it to the square root of
             second moment estimate as in the original paper. (default: False)
+        use_mt (boolean, optional): use multi tensor apply for lower launch
+            latency. (default: False)

     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980

@@ -35,10 +39,21 @@ class FusedAdam(torch.optim.Optimizer):
     def __init__(self, params,
                  lr=1e-3, bias_correction = True,
                  betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
-                 weight_decay=0., max_grad_norm=0., amsgrad=False):
+                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
+                 amp_scale_adjustment=1.0):
         global fused_adam_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")

+        self._use_multi_tensor = False
+        if use_mt:
+            if not multi_tensor_applier.available:
+                print("Warning: multi_tensor_applier is unavailable")
+            else:
+                self._use_multi_tensor = True
+                self._overflow_buf = torch.cuda.IntTensor([0])
+
+        self._amp_scale_adjustment = amp_scale_adjustment
+
         if amsgrad:
             raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
         defaults = dict(lr=lr, bias_correction=bias_correction,

@@ -66,6 +81,12 @@ class FusedAdam(torch.optim.Optimizer):
         if closure is not None:
             loss = closure()

+        if hasattr(self, "_amp_stash"):
+            grads = self._amp_stash.grads
+            output_params = self._amp_stash.output_params
+            scale = self._amp_stash.scale*self._amp_scale_adjustment
+            grad_norms = self._amp_stash.grad_norms
+
         if grads is None:
             grads_group = [None]*len(self.param_groups)
         # backward compatibility

@@ -105,6 +126,12 @@ class FusedAdam(torch.optim.Optimizer):

             bias_correction = 1 if group['bias_correction'] else 0

+            if self._use_multi_tensor:
+                if output_params:
+                    tensorlists = [[],[],[],[],[]]
+                else:
+                    tensorlists = [[],[],[],[]]
+
             for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                 #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                 if p.grad is None and grad is None:

@@ -130,18 +157,43 @@ class FusedAdam(torch.optim.Optimizer):
                 state['step'] += 1

                 out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
-                fused_adam_cuda.adam(p.data,
-                                     out_p,
-                                     exp_avg,
-                                     exp_avg_sq,
-                                     grad,
-                                     group['lr'],
-                                     beta1,
-                                     beta2,
-                                     group['eps'],
-                                     combined_scale,
-                                     state['step'],
-                                     self.eps_mode,
-                                     bias_correction,
-                                     group['weight_decay'])
+                if self._use_multi_tensor:
+                    pl = [p.data, exp_avg, exp_avg_sq, grad]
+                    if output_param is not None:
+                        pl.append(out_p)
+
+                    for tl, t in zip(tensorlists, pl):
+                        tl.append(t)
+                else:
+                    fused_adam_cuda.adam(p.data,
+                                         out_p,
+                                         exp_avg,
+                                         exp_avg_sq,
+                                         grad,
+                                         group['lr'],
+                                         beta1,
+                                         beta2,
+                                         group['eps'],
+                                         combined_scale,
+                                         state['step'],
+                                         self.eps_mode,
+                                         bias_correction,
+                                         group['weight_decay'])
+
+            if self._use_multi_tensor:
+                multi_tensor_applier(
+                    fused_adam_cuda.adam_mt,
+                    self._overflow_buf,
+                    tensorlists,
+                    group['lr'],
+                    beta1,
+                    beta2,
+                    group['eps'],
+                    combined_scale,
+                    state['step'],
+                    self.eps_mode,
+                    bias_correction,
+                    group['weight_decay'])

         return loss
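Aside: a hedged sketch of the new `use_mt` option in isolation (model and hyperparameters are placeholders). When `multi_tensor_applier` is unavailable, the constructor above prints a warning and falls back to the per-tensor kernel.

```python
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(64, 64).cuda()
# use_mt batches the per-parameter Adam updates through one multi-tensor launch.
optimizer = FusedAdam(model.parameters(), lr=1e-3, use_mt=True)

loss = model(torch.randn(8, 64, device="cuda")).sum()
loss.backward()
optimizer.step()
```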
@@ -44,7 +44,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
         if call is dist.all_reduce:
             coalesced /= dist.get_world_size()

     for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
         buf.copy_(synced)

@@ -54,7 +54,7 @@ def split_half_float_double(tensors):
     for i, dtype in enumerate(dtypes):
         bucket = [t for t in tensors if t.type() == dtype]
         if bucket:
             buckets.append(bucket)
     return buckets

 def split_by_type(tensors):

@@ -69,12 +69,12 @@ def split_by_type(tensors):
 # flat_dist_call organizes 'tensors' by type.
 def flat_dist_call(tensors, call, extra_args=None):
     buckets = split_by_type(tensors)

     for tp in buckets:
         bucket = buckets[tp]
         apply_flat_dist_call(bucket, call, extra_args)


 def extract_tensors(maybe_tensor, tensor_list):
     if torch.is_tensor(maybe_tensor):
         tensor_list.append(maybe_tensor)

@@ -85,7 +85,7 @@ def extract_tensors(maybe_tensor, tensor_list):
         except TypeError:
             return


 class Reducer(object):
     """
     :class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters

@@ -93,13 +93,13 @@ class Reducer(object):
     Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
     parameters during ``backward()``.
     Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
     This enables, for example, delaying the allreduce to be carried out every
     several iterations instead of every single iteration.

     Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
     over the number of participating processes.

     :class:`Reducer` is designed to work with the upstream launch utility script
     ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
     When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
     It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

@@ -109,7 +109,7 @@ class Reducer(object):
     Args:
         module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
     """

     def __init__(self, module_or_grads_list):
         if isinstance(module_or_grads_list, Module):
             self.module = module_or_grads_list

@@ -119,26 +119,26 @@ class Reducer(object):
             self.module = None
             self.grads = []
             extract_tensors(module_or_grads_list, self.grads)

     def reduce(self):
         if self.module:
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
             flat_dist_call(grads, dist.all_reduce)
         else:
             flat_dist_call(self.grads, dist.all_reduce)
 class DistributedDataParallel(Module):
     """
     :class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
     easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
     allreduced and averaged over processes during ``backward()``.

     :class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
     overlapping communication with computation during ``backward()`` and bucketing smaller gradient
     transfers to reduce the total number of transfers required.

     :class:`DistributedDataParallel` is designed to work with the upstream launch utility script
     ``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
     When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
     It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.

@@ -161,20 +161,22 @@ class DistributedDataParallel(Module):
     """

     def __init__(self,
                  module,
                  message_size=10000000,
                  delay_allreduce=False,
                  shared_param=None,
                  allreduce_trigger_params=None,
                  retain_allreduce_buffers=False,
                  allreduce_always_fp32=False,
+                 allreduce_different_streams=False,
                  gradient_average=True,
                  gradient_predivide_factor=1.0,
-                 gradient_average_split_factor=None):
+                 gradient_average_split_factor=None,
+                 prof=False):
         super(DistributedDataParallel, self).__init__()

         # Backward/forward compatibility around
         # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
         # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
         if hasattr(dist, "get_backend"):

@@ -184,13 +186,20 @@ class DistributedDataParallel(Module):
             else:
                 self.backend_enum_holder = dist.Backend
         else:
             self._backend = dist._backend
             self.backend_enum_holder = dist.dist_backend

         self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False

+        self.prof = prof
+
+        if allreduce_different_streams and delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
+        self.allreduce_different_streams = allreduce_different_streams
+
         if shared_param is not None:
             raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")

         if gradient_average_split_factor is not None:
             print("Warning: gradient_average_split_factor has been renamed to gradient_predivide_factor. For now, gradient_average_split_factor will also work, but please update to gradient_predivide_factor instead.")

@@ -206,25 +215,27 @@ class DistributedDataParallel(Module):
         self.custom_allreduce_triggers = False
         if allreduce_trigger_params is not None:
             if delay_allreduce:
                 raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
             self.custom_allreduce_triggers = True
             self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])

         self.delay_allreduce = delay_allreduce
         self.message_size = message_size

-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        self.main_stream = torch.cuda.current_stream()
+
+        self.bucket_streams = []
+        self.bucket_events = []

         self.module = module

         if self._backend == self.backend_enum_holder.NCCL:
             for param in self.module.parameters():
                 assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."

         self.active_params = []

         self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                     "torch.cuda.FloatTensor" : 1,
                                     "torch.cuda.DoubleTensor" : 2}

@@ -241,19 +252,25 @@ class DistributedDataParallel(Module):

     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
-        self.reduction_stream = torch.cuda.Stream()
-        self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
+        if allreduce_different_streams and delay_allreduce:
+            raise ValueError("allreduce_different_streams may only be used if delay_allreduce=False.")
+
+        if self.delay_allreduce:
+            self.needs_refresh = True
+
+        self.bucket_streams = []
+        self.bucket_events = []

     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
         if self._backend != self.backend_enum_holder.NCCL:
-            del attrs['self.reduction_stream']
-            del attrs['self.reduction_event']
+            del attrs['self.bucket_streams']
+            del attrs['self.bucket_events']
         return attrs

     # Broadcast rank 0's bucket structure across all processes, and have all processes
     # regenerate their bucket structures to match.
     def sync_bucket_structure(self):
         # Append leftover buckets
         for tmp_bucket in self.tmp_buckets:

@@ -263,8 +280,8 @@ class DistributedDataParallel(Module):
         self.num_buckets = len(self.active_i_buckets)
         self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]

         info_tensor = torch.cuda.IntTensor([self.num_buckets] +
                                            self.bucket_sizes +
                                            list(chain(*self.active_i_buckets)))

         dist.broadcast(info_tensor, 0)
@@ -272,27 +289,27 @@ class DistributedDataParallel(Module):
         info = [int(entry) for entry in info_tensor]

         self.num_buckets = info[0]
         self.bucket_sizes = info[1:self.num_buckets + 1]

         self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                         for i in range(self.num_buckets)]

         # Technically, active_i_buckets' work is done. But the information is still useful to
         # keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
         self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
                                  for i in range(self.num_buckets)]

         flattened_buckets = info[self.num_buckets + 1:]
         flat_i = 0
         for bucket_idx in range(self.num_buckets):
             for bucket_loc in range(self.bucket_sizes[bucket_idx]):
                 param_i = flattened_buckets[flat_i]
                 self.active_i_buckets[bucket_idx][bucket_loc] = param_i
                 self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
                 flat_i += 1

     def create_hooks(self):
         # Fallback hook that's only called at the end of backward.
         # Used if you deliberately want to delay allreduces to the end, or to refresh the
         # bucket structure that will be used to overlap communication with computation in later
         # iterations.
         def allreduce_params():

@@ -307,9 +324,10 @@ class DistributedDataParallel(Module):

         def overlapping_backward_epilogue():
-            self.reduction_stream.record_event(self.reduction_event)
-            torch.cuda.current_stream().wait_event(self.reduction_event)
+            for stream, event in zip(self.bucket_streams, self.bucket_events):
+                stream.record_event(event)
+                torch.cuda.current_stream().wait_event(event)

             # Sanity checks that all the buckets were kicked off
             if self.next_bucket != self.num_buckets:
                 raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(

@@ -319,7 +337,7 @@ class DistributedDataParallel(Module):
             for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
                 if actual != expected:
                     raise RuntimeError("Some param buckets were not allreduced.")

         self.grad_accs = []
         for param in self.module.parameters():

@@ -329,6 +347,9 @@ class DistributedDataParallel(Module):
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]

                     def allreduce_hook(*unused):
+                        if self.prof:
+                            torch.cuda.nvtx.range_push("allreduce_hook")
+
                         if self.delay_allreduce or self.needs_refresh:
                             # TODO: How do we want to handle multiple backward passes between
                             # each forward, e.g., backward passes with retain_graph=True?

@@ -339,8 +360,8 @@ class DistributedDataParallel(Module):
                                 # Float, half, and double tensors are grouped into buckets separately.
                                 current_type = self.param_type_to_tmp_i[param.type()]

                                 self.tmp_buckets[current_type].append(active_i)

                                 ship_tmp_bucket = False
                                 if self.custom_allreduce_triggers:

@@ -357,81 +378,136 @@ class DistributedDataParallel(Module):
                                     self.active_i_buckets.append(self.tmp_buckets[current_type])
                                     self.tmp_buckets[current_type] = []
                                     self.tmp_numels[current_type] = 0

                             if not self.callback_queued:
                                 Variable._execution_engine.queue_callback(allreduce_params)
                                 self.callback_queued = True
                         else:
                             if not self.callback_queued:
                                 Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                                 self.callback_queued = True

                             self.comm_ready_buckets(param)

+                        if self.prof:
+                            torch.cuda.nvtx.range_pop()
+
                     grad_acc.register_hook(allreduce_hook)
                     self.grad_accs.append(grad_acc)

             wrapper(param)
-    def allreduce_bucket(self, bucket):
+    def _stream_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_streams[bucket_idx]
+        else:
+            return self.bucket_streams[0]
+
+    def _event_this_bucket(self, bucket_idx):
+        if self.allreduce_different_streams:
+            return self.bucket_events[bucket_idx]
+        else:
+            return self.bucket_events[0]
+
+    def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
         tensor = flatten(bucket)

-        tensor_to_allreduce = tensor
-
-        if self.allreduce_always_fp32:
-            tensor_to_allreduce = tensor.float()
-
-        if self.gradient_predivide_factor != 1.0:
-            tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
-
-        dist.all_reduce(tensor_to_allreduce)
-
-        if self.gradient_average:
-            tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
-
-        if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
-            tensor.copy_(tensor_to_allreduce)
+        if force_default_stream:
+            bucket_stream = self.main_stream
+        else:
+            bucket_stream = self._stream_this_bucket(bucket_idx)
+            bucket_event = self._event_this_bucket(bucket_idx)
+            torch.cuda.current_stream().record_event(bucket_event)
+            bucket_stream.wait_event(bucket_event)
+
+        with torch.cuda.stream(bucket_stream):
+            # self.main_stream.wait_stream(torch.cuda.current_stream())
+            # torch.cuda.synchronize()
+
+            tensor_to_allreduce = tensor
+
+            if self.allreduce_always_fp32:
+                tensor_to_allreduce = tensor.float()
+
+            if self.gradient_predivide_factor != 1.0:
+                tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
+
+            if self.allreduce_different_streams and self.bucket_pgs:
+                dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx])
+            else:
+                dist.all_reduce(tensor_to_allreduce)
+
+            if self.gradient_average:
+                tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
+
+            if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
+                tensor.copy_(tensor_to_allreduce)
+
+            if not self.retain_allreduce_buffers:
+                if multi_tensor_applier.available:
+                    multi_tensor_applier(
+                        self.multi_tensor_scale,
+                        self._overflow_buf,
+                        [unflatten(tensor, bucket), bucket],
+                        1.0)
+                else:
+                    for buf, synced in zip(bucket, unflatten(tensor, bucket)):
+                        buf.copy_(synced)
+
+            # I think we actually do need this here. After allreduce_bucket returns, tensor will
+            # eventually go out of scope and die, at which point it could otherwise be freed for
+            # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
+            tensor.record_stream(bucket_stream)
+
+        # torch.cuda.synchronize()

         return tensor
-    def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
-        allreduced = self.allreduce_bucket(bucket)
+    def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
+        allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
         if self.retain_allreduce_buffers:
             if self.allreduce_buffers[bucket_idx] is not None:
                 raise RuntimeError("The backward pass is attempting to replace an already-filled "
                                    "allreduce buffer. This is almost certainly an error.")
             self.allreduce_buffers[bucket_idx] = allreduced
-        else:
-            if multi_tensor_applier.available:
-                multi_tensor_applier(
-                    self.multi_tensor_scale,
-                    self._overflow_buf,
-                    [unflatten(allreduced, bucket), bucket],
-                    1.0)
-            else:
-                for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
-                    buf.copy_(synced)
+            for view, grad in zip(unflatten(allreduced, bucket), bucket):
+                grad.data = view
+            # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
+            #     buf.copy_(synced)

     def allreduce_fallback(self):
-        grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+        for stream, event in zip(self.bucket_streams, self.bucket_events):
+            stream.record_event(event)
+            torch.cuda.current_stream().wait_event(event)
+
+        if self.retain_allreduce_buffers:
+            grads = [param.grad for param in self.module.parameters() if param.grad is not None]
+        else:
+            grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]

         split_buckets = split_half_float_double(grads)

         # If retain_allreduce_buffers is True and delay_allreduce is False,
         # this will only be done during the first backward pass, ignored by the
         # training script, and overwritten in the next forward pass. So it's harmless.
         if self.retain_allreduce_buffers:
             self.allreduce_buffers = [None for _ in range(len(split_buckets))]

         for i, bucket in enumerate(split_buckets):
-            allreduced = self.allreduce_maybe_retain(bucket, i)
+            allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
     def comm_ready_buckets(self, param):
         # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
         # self.reduction_stream.wait_stream(torch.cuda.current_stream())
+        if self.prof:
+            torch.cuda.nvtx.range_push("comm_ready_buckets")

         bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]

@@ -439,39 +515,46 @@ class DistributedDataParallel(Module):
             raise RuntimeError("The backward pass is attempting to replace an already-filled "
                                "bucket slot. This is almost certainly an error.")

-        self.buckets[bucket_idx][bucket_loc] = param.grad.data
+        if self.retain_allreduce_buffers:
+            self.buckets[bucket_idx][bucket_loc] = param.grad
+        else:
+            self.buckets[bucket_idx][bucket_loc] = param.grad.data
+
         self.buckets_ready_size[bucket_idx] += 1

         if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
             if bucket_idx == self.next_bucket:
-                torch.cuda.current_stream().record_event(self.reduction_event)
-                self.reduction_stream.wait_event(self.reduction_event)
-                with torch.cuda.stream(self.reduction_stream):
-                    self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
-
-                    self.next_bucket += 1
-
-                    # Reversing upstream's logic here, because we constructed our buckets based on
-                    # the order things were received during backward.
-                    if len(self.ready_buckets_not_reduced) > 0:
-                        sorted_todo = sorted(self.ready_buckets_not_reduced)
-                        for i in sorted_todo:
-                            # Nothing can be reduced now
-                            if i > self.next_bucket:
-                                break
-                            elif i == self.next_bucket:
-                                self.allreduce_maybe_retain(self.buckets[i], i)
-                                self.ready_buckets_not_reduced.remove(i)
-                                self.next_bucket += 1
-                            else:
-                                raise ValueError("i should always be >= next_bucket")
+                self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
+
+                self.next_bucket += 1
+
+                # Reversing upstream's logic here, because we constructed our buckets based on
+                # the order things were received during backward.
+                if len(self.ready_buckets_not_reduced) > 0:
+                    sorted_todo = sorted(self.ready_buckets_not_reduced)
+                    for i in sorted_todo:
+                        # Nothing can be reduced now
+                        if i > self.next_bucket:
+                            break
+                        elif i == self.next_bucket:
+                            self.allreduce_maybe_retain(self.buckets[i], i)
+                            self.ready_buckets_not_reduced.remove(i)
+                            self.next_bucket += 1
+                        else:
+                            raise ValueError("i should always be >= next_bucket")
             else:
                 self.ready_buckets_not_reduced.add(bucket_idx)

+        if self.prof:
+            torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs) result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self.delay_allreduce: if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad] param_list = [param for param in self.module.parameters() if param.requires_grad]
...@@ -479,7 +562,7 @@ class DistributedDataParallel(Module): ...@@ -479,7 +562,7 @@ class DistributedDataParallel(Module):
# Forward has the authority to set needs_refresh to True, but only allreduce_params # Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False. # in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer. # Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or (len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])): any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True self.needs_refresh = True
...@@ -490,19 +573,53 @@ class DistributedDataParallel(Module): ...@@ -490,19 +573,53 @@ class DistributedDataParallel(Module):
self.tmp_buckets = [[], [], []] # [running half, float, double buckets] self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0] self.tmp_numels = [0, 0, 0]
self.bucket_sizes = [] self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)} self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {} self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else: else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])] self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)] for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
b, len(buckets[b]), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_buckets)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_buckets)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_buckets)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
        self.buckets_ready_size = [0 for i in range(self.num_buckets)]
        if(self.retain_allreduce_buffers):
            self.allreduce_buffers = [None for _ in range(self.num_buckets)]
        self.next_bucket = 0
        self.ready_buckets_not_reduced = set()

    self.active_params = param_list

self.callback_queued = False

if self.prof:
    torch.cuda.nvtx.range_pop()

return result
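For context, a minimal sketch of how this wrapper is typically driven; this is not part of the diff, and the bucket/stream bookkeeping above is internal to forward and backward. It assumes a standard `torch.distributed` launch that provides `LOCAL_RANK`.

```python
# Sketch only: wrapping a model in apex.parallel.DistributedDataParallel.
# delay_allreduce=False keeps the bucketed, overlapped allreduce path shown above.
import os
import torch
import apex

torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(128, 10).cuda()
model = apex.parallel.DistributedDataParallel(model, delay_allreduce=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    x = torch.randn(32, 128, device="cuda")
    y = torch.randint(0, 10, (32,), device="cuda")
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()   # bucketed allreduces (and the streams set up above) fire here
    optimizer.step()
```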
@@ -3,6 +3,9 @@

// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("adam", &adam, "Adam optimized CUDA implementation.");
  m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
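For orientation, a hedged sketch of calling the new `adam_mt` binding directly from Python. The extension is assumed to build under the name `fused_adam_cuda` (as consumed by `apex.optimizers.FusedAdam`); only the argument order and the `[p, m, v, g, p_copy]` tensor-list layout are taken from the declarations in this file. Real training code should go through `FusedAdam` rather than this raw call.

```python
# Hedged sketch: exercising adam_mt directly. The module name, noop_flag dtype,
# and chunk size are assumptions; the argument order mirrors the C++ declaration.
import torch
import fused_adam_cuda  # compiled extension, assumed name

p = torch.randn(1024, device="cuda")              # fp32 master parameters
m = torch.zeros_like(p)                           # exp_avg
v = torch.zeros_like(p)                           # exp_avg_sq
g = torch.randn(1024, device="cuda").half()       # fp16 gradients
p_copy = p.clone().half()                         # fp16 parameter copy (optional 5th list)

noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")
tensor_lists = [[p], [m], [v], [g], [p_copy]]     # p, m, v, g, p_copy

fused_adam_cuda.adam_mt(
    2048 * 32,        # chunk_size (illustrative)
    noop_flag,
    tensor_lists,
    1e-3,             # lr
    0.9, 0.999,       # beta1, beta2
    1e-8,             # eps
    1.0,              # grad_scale
    1,                # step
    0,                # mode 0: eps added inside the sqrt
    1,                # bias_correction
    0.0)              # weight decay
```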
@@ -9,6 +9,10 @@

#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

#include "type_shim.h"

@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(

  }
}
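// AdamFunctor is the per-chunk body of the multi-tensor path: multi_tensor_apply
// hands each CUDA block one (tensor, chunk) pair via TensorListMetadata, and the
// functor applies the Adam update to that chunk. DEPTH is 4 for [p, m, v, g] and
// 5 when an fp16 copy of the parameters (p_copy) is also written back.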
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
    __device__ __forceinline__ void operator()(
        int chunk_size,
        volatile int* noop_gmem,
        TensorListMetadata<DEPTH>& tl,
        const float b1,
        const float b2,
        const float eps,
        const float grad_scale,
        const float step_size,
        adamMode_t mode,
        const float decay)
    {
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* p = (T *)tl.addresses[0][tensor_loc];
        p += chunk_idx*chunk_size;
        T* m = (T *)tl.addresses[1][tensor_loc];
        m += chunk_idx*chunk_size;
        T* v = (T *)tl.addresses[2][tensor_loc];
        v += chunk_idx*chunk_size;
        GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
        g += chunk_idx*chunk_size;

        GRAD_T* p_copy = NULL;
        if (DEPTH == 5) {
            p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
            p_copy += chunk_idx*chunk_size;
        }

        n -= chunk_idx*chunk_size;

        T incoming_p[ILP];
        T incoming_m[ILP];
        T incoming_v[ILP];
        T incoming_g[ILP];

        for(int i_start = 0;
            i_start < n && i_start < chunk_size;
            i_start += blockDim.x*ILP) {

            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
                incoming_p[ii] = 0;
                incoming_m[ii] = 0;
                incoming_v[ii] = 0;
                incoming_g[ii] = 0;

                int i = i_start + threadIdx.x + ii*blockDim.x;
                if (i < n && i < chunk_size) {
                    incoming_p[ii] = p[i];
                    incoming_m[ii] = m[i];
                    incoming_v[ii] = v[i];
                    incoming_g[ii] = static_cast<T>(g[i]);
                }
            }

            // note for clarification to future michael:
            // From a pure memory dependency perspective, there's likely no point unrolling
            // the write loop, since writes just fire off once their LDGs arrive.
            // Put another way, the STGs are dependent on the LDGs, but not on each other.
            // There is still compute ILP benefit from unrolling the loop though.
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
                int j = i_start + threadIdx.x + ii*blockDim.x;

                if(j < n && j < chunk_size) {
                    T scaled_grad = incoming_g[ii]/grad_scale;
                    m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                    v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                    float denom;
                    if (mode == ADAM_MODE_0)
                        denom = sqrtf(v[j] + eps);
                    else // Mode 1
                        denom = sqrtf(v[j]) + eps;
                    float update = (m[j]/denom) + (decay*incoming_p[ii]);
                    p[j] = incoming_p[ii] - (step_size*update);
                    if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
                }
            }
        }
    }
};
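As a readability aid (a reference sketch only, not used anywhere in the extension), the element-wise math performed by the functor can be written in NumPy as follows; argument names mirror the kernel arguments.

```python
# NumPy transcription of the per-element update in AdamFunctor. mode 0 adds eps
# inside the sqrt, mode 1 outside; step_size is assumed to already include any
# bias correction (see fused_adam_cuda_mt below).
import numpy as np

def adam_reference_update(p, m, v, g, b1, b2, eps, grad_scale, step_size, mode, decay):
    scaled_grad = g.astype(np.float32) / grad_scale
    m[:] = b1 * m + (1 - b1) * scaled_grad
    v[:] = b2 * v + (1 - b2) * scaled_grad * scaled_grad
    denom = np.sqrt(v + eps) if mode == 0 else np.sqrt(v) + eps
    update = m / denom + decay * p
    p[:] = p - step_size * update
    return p, m, v
```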
void fused_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,

@@ -135,3 +226,110 @@ void fused_adam_cuda(

    THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay) {

    //Constants
    float step_size = 0;
    if (bias_correction == 1) {
        const float bias_correction1 = 1 - std::pow(beta1, step);
        const float bias_correction2 = 1 - std::pow(beta2, step);
        step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
    }
    else {
        step_size = lr;
    }
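    // When bias correction is enabled, the usual Adam correction terms are folded
    // into the step size, step_size = lr * sqrt(1 - beta2^step) / (1 - beta1^step),
    // so the kernel only ever sees the already-corrected step size.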
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    size_t tl_sz = tensor_lists.size();
    AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
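    // A 4-deep tensor list is [p, m, v, g]; a 5-deep list additionally carries an
    // fp16 parameter copy (p_copy) that the functor writes back after each update
    // (the DEPTH == 5 branch in AdamFunctor above).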
    if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) {
        // all other values should be fp32 for half gradients
        AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
        // dispatch is done on the gradient type
        if (tl_sz == 5) {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, accscalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        } else {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                using accscalar_t = at::acc_type<scalar_t, true>;
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, accscalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        }
    } else {
        if (tl_sz == 5) {
            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                multi_tensor_apply<5>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<5, scalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        } else {
            AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
                multi_tensor_apply<4>(
                    BLOCK_SIZE,
                    chunk_size,
                    noop_flag,
                    tensor_lists,
                    AdamFunctor<4, scalar_t, scalar_t>(),
                    beta1,
                    beta2,
                    eps,
                    grad_scale,
                    step_size,
                    (adamMode_t) mode,
                    decay);
            }));
        }
    }
    THCudaCheck(cudaGetLastError());
}
@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):

    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
        if tst_adam_option:
            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
        else:
            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)

        return (ref_param, tst_param, ref_optim, tst_optim)
@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):

    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p

@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):

        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)
    def test_multi_tensor(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        ref_adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
                           'weight_decay':0, 'amsgrad':False}
        tst_adam_option = dict(ref_adam_option, **{'use_mt':True})

        tensors = []
        fp16_params = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)

        for i in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=fp16_params)

            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)

            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
                fp16_params)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
...
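A short usage sketch following the pattern exercised by `test_multi_tensor` above (the `use_mt`, `grads=`, and `output_params=` arguments appear in that test; the parameter and gradient bookkeeping is assumed to be handled by the surrounding training loop):

```python
# Sketch only: one mixed-precision step through the multi-tensor FusedAdam path,
# mirroring test_multi_tensor. fp32 master params are updated from fp16 grads,
# and fp16_params receives the refreshed fp16 copy of the weights.
import torch
import apex

fp32_params = [torch.nn.Parameter(torch.randn(4096, 1024, device='cuda'))]
fp16_params = [torch.nn.Parameter(p.data.clone().half()) for p in fp32_params]

optimizer = apex.optimizers.FusedAdam(fp32_params, lr=5e-4, betas=(0.9, 0.999),
                                      eps=1e-8, weight_decay=0, use_mt=True)

half_grads = [torch.randn_like(p) for p in fp16_params]   # stand-in for real grads
optimizer.step(grads=half_grads, output_params=fp16_params)
```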