Commit 9eab1ac3 authored by Michael Carilli

Merge branch 'master' of https://github.com/NVIDIA/apex

parents 559141e8 2361a646
from . import compat, utils, wrap
from . import compat, rnn_compat, utils, wrap
from .handle import AmpHandle, NoOpHandle
from .lists import functional_overrides, torch_overrides, tensor_overrides
@@ -73,7 +73,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
# 0.5) Force-promote for user-annotated functions
for mod, fn in _USER_PROMOTE_REGISTRY:
wrap.promote(mod, fn, verbose)
wrap.promote(mod, fn, handle, verbose)
_USER_PROMOTE_REGISTRY.clear()
# 1) Force-{fp16, fp32} on white- / black-list functions
@@ -107,7 +107,7 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
for promote_mod, (list_name, promote_fn) in itertools.product(promote_modules,
promote_table):
for fn in getattr(promote_mod, list_name):
promote_fn(promote_mod.MODULE, fn, verbose)
promote_fn(promote_mod.MODULE, fn, handle, verbose)
# 2.5) Pre-0.4, add blacklist methods directly to HalfTensor and FloatTensor types
if compat.tensor_is_float_tensor():
@@ -115,48 +115,49 @@ def init(enabled=True, enable_caching=True, verbose=False, allow_banned=False):
torch.cuda.HalfTensor],
promote_table):
for fn in getattr(tensor_overrides, list_name):
promote_fn(cls, fn, verbose)
promote_fn(cls, fn, handle, verbose)
# 3) For any in-place version of a blacklist function, error if any input is fp16.
# NB: this is overly conservative.
for fn in utils.as_inplace(torch_overrides.FP32_FUNCS):
wrap.err_if_any_half(torch_overrides.MODULE, fn)
wrap.err_if_any_half(torch_overrides.MODULE, fn, handle)
# 3.5) For any in-place blacklist method, error if called on fp16 tensor
for fn in utils.as_inplace(tensor_overrides.FP32_FUNCS):
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, verbose)
wrap.err_if_arg0_half(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, verbose)
wrap.err_if_arg0_half(torch.cuda.HalfTensor, fn, handle, verbose)
# 4) For other in-place methods, match the type of self tensor
for fn in utils.as_inplace(itertools.chain(
tensor_overrides.FP16_FUNCS,
tensor_overrides.CASTS)):
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, verbose)
wrap.promote_match_arg0(tensor_overrides.MODULE, fn, handle, verbose)
if compat.tensor_is_float_tensor():
wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, verbose)
wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, verbose)
# 5) Special handling to whitelist RNN cell backend impls.
for fn in ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']:
wrap.cached_cast(torch.nn.backends.thnn.backend, fn, utils.maybe_half,
handle, try_caching=True, verbose=verbose)
# 5.5) Extra-special handling of RNN backend
wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', verbose)
# And even more special handling of `backward` for fused gru / lstm
# The `backward` method calls Tensor.sum() (blacklist) internally,
# and then the resulting grad_input has the wrong type.
# TODO: where else is this a problem?
for rnn_type in ['GRUFused', 'LSTMFused']:
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
wrap.disable_casts(mod, 'backward', handle)
# 6) Place error+print message on banned functions
if not allow_banned:
for fn, err_msg in functional_overrides.BANNED_FUNCS:
wrap.err_if_any_half(functional_overrides.MODULE, fn, err_msg)
wrap.promote_match_arg0(torch.cuda.HalfTensor, fn, handle, verbose)
wrap.promote_match_arg0(torch.cuda.FloatTensor, fn, handle, verbose)
# 5) RNNs + RNN cells are whitelisted specially
if rnn_compat.has_old_rnns():
wrap.rnn_cast(torch.nn.backends.thnn.backend, 'RNN', handle, verbose)
if not rnn_compat.has_old_rnns():
# Patch in our own indirection of `_VF` in modules/rnn s.t. it is mutable.
torch.nn.modules.rnn._VF = rnn_compat.VariableFunctionsShim()
# Wrap all the rnns
for x in rnn_compat.RNN_NAMES:
wrap.new_rnn_cast(x.upper(), handle, verbose)
# Wrap all the RNN cells
rnn_compat.whitelist_rnn_cells(handle, verbose)
# 6) Place error+print message on banned functions.
# Or, if allow_banned, then cast to FP32.
for fn, err_msg in functional_overrides.BANNED_FUNCS:
if allow_banned:
wrap.cached_cast(functional_overrides.MODULE, fn, utils.maybe_float,
handle, try_caching=True, verbose=verbose)
else:
wrap.err_if_any_half(functional_overrides.MODULE, fn, handle, err_msg)
_DECORATOR_HANDLE = handle
return handle
@@ -2,6 +2,7 @@ import contextlib
import logging
import warnings
from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler
@@ -12,6 +13,7 @@ class AmpHandle(object):
self._cache = dict()
self._default_scaler = LossScaler()
self._is_active = True
self._all_wrappers = []
def is_active(self):
return self._is_active
@@ -63,6 +65,15 @@ class AmpHandle(object):
def _clear_cache(self):
self._cache.clear()
# Experimental support for saving / restoring uncasted versions of functions
def _save_func(self, mod, fn, func):
self._all_wrappers.append((mod, fn, func))
def _deactivate(self):
for mod, fn, func in self._all_wrappers:
utils.set_func(mod, fn, func)
self._all_wrappers = []
@property
def has_cache(self):
return self._enable_caching
......
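Editorial note: the `_save_func` / `_deactivate` pair above works together with `utils.set_func_save`, added later in this commit, which records every (module, name, original function) triple on the handle before installing a wrapper. A minimal sketch of how the new tests rely on this, assuming the `apex.amp` package as patched by this commit:

```python
from apex import amp

handle = amp.init(enabled=True)   # monkey-patches torch functions, recording each original on the handle
# ... run fp16-friendly code with casts inserted ...
handle._deactivate()              # restores every patched function to its original, unwrapped version
```

This is what the unit tests further down do in `tearDown`, so that one test's patching does not leak into the next.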
from . import utils, wrap
import torch
_VF = torch._C._VariableFunctions
RNN_NAMES = ['rnn_relu', 'rnn_tanh', 'gru', 'lstm']
def _gen_VF_wrapper(name):
def wrapper(*args, **kwargs):
return getattr(_VF, name)(*args, **kwargs)
return wrapper
# Some python magic to generate an object that has the rnn cell functions
# defined on it, all of which call into corresponding _VF version.
class VariableFunctionsShim(object):
def __init__(self):
for name in RNN_NAMES:
setattr(self, name + '_cell', _gen_VF_wrapper(name + '_cell'))
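For context (an editorial sketch, not part of the diff): the comment in `amp.init` above says the shim is patched into `torch.nn.modules.rnn` "s.t. it is mutable", i.e. so that `whitelist_rnn_cells` can later overwrite the `*_cell` attributes with casting wrappers. Assuming the module is importable as `apex.amp.rnn_compat` (inferred from the relative imports above), the patch amounts to:

```python
import torch
from apex.amp import rnn_compat  # assumed import path, not stated explicitly in this diff

shim = rnn_compat.VariableFunctionsShim()
torch.nn.modules.rnn._VF = shim   # what amp.init does for the new-style RNNs
# Until amp wraps them, each attribute simply forwards to the real implementation:
#   shim.lstm_cell(*args) -> torch._C._VariableFunctions.lstm_cell(*args)
```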
def has_old_rnns():
try:
torch.nn.backends.thnn.backend.LSTMCell
return True
except:
return False
def whitelist_rnn_cells(handle, verbose):
# Different module + function names in old/new RNN cases
if has_old_rnns():
fn_names = ['RNNReLUCell', 'RNNTanhCell', 'LSTMCell', 'GRUCell']
mod = torch.nn.backends.thnn.backend
else:
fn_names = [x + '_cell' for x in RNN_NAMES]
mod = torch.nn.modules.rnn._VF
assert isinstance(mod, VariableFunctionsShim)
# Insert casts on cell functions
for fn in fn_names:
wrap.cached_cast(mod, fn, utils.maybe_half, handle,
try_caching=True, verbose=verbose)
if has_old_rnns():
# Special handling of `backward` for fused gru / lstm:
# The `backward` method calls Tensor.sum() (blacklist) internally,
# and then the resulting grad_input has the wrong type.
# TODO: where else is this a problem?
for rnn_type in ['GRUFused', 'LSTMFused']:
mod = getattr(torch.nn._functions.thnn.rnnFusedPointwise, rnn_type)
wrap.disable_casts(mod, 'backward', handle)
@@ -111,21 +111,32 @@ def as_inplace(fns):
def has_func(mod, fn):
if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
return fn in mod.function_classes
elif isinstance(mod, dict):
return fn in mod
else:
return hasattr(mod, fn)
def get_func(mod, fn):
if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
return mod.function_classes[fn]
elif isinstance(mod, dict):
return mod[fn]
else:
return getattr(mod, fn)
def set_func(mod, fn, new_fn):
if isinstance(mod, torch.nn.backends.backend.FunctionBackend):
mod.function_classes[fn] = new_fn
elif isinstance(mod, dict):
mod[fn] = new_fn
else:
setattr(mod, fn, new_fn)
def set_func_save(handle, mod, fn, new_fn):
cur_fn = get_func(mod, fn)
handle._save_func(mod, fn, cur_fn)
set_func(mod, fn, new_fn)
# A couple problems get solved here:
# - The flat_weight buffer is disconnected from autograd graph,
# so the fp16 weights need to be derived from the input weights
@@ -160,3 +171,23 @@ def synthesize_flattened_rnn_weights(fp32_weights,
fp16_layer_weights.append(w_fp16)
fp16_weights.append(fp16_layer_weights)
return fp16_weights
# Roughly same as above, just the `fp32_weights` aren't nested.
# Code kept separate for readability.
def new_synthesize_flattened_rnn_weights(fp32_weights,
fp16_flat_tensor,
rnn_fn='',
verbose=False):
fp16_weights = []
fp32_base_ptr = fp32_weights[0].data_ptr()
for w_fp32 in fp32_weights:
w_fp16 = w_fp32.new().half()
offset = (w_fp32.data_ptr() - fp32_base_ptr) // w_fp32.element_size()
w_fp16.set_(fp16_flat_tensor.storage(),
offset,
w_fp32.shape)
w_fp16.copy_(w_fp32)
if verbose:
print('Float->Half ({})'.format(rnn_fn))
fp16_weights.append(w_fp16)
return fp16_weights
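A worked sketch of what the helper above produces (editorial example; illustrative shapes, CPU tensors for brevity, and it assumes the helper is in scope as defined in amp's utils module above). The fp16 views all alias one flat half-precision buffer at the same offsets the fp32 weights occupy in their own flat storage:

```python
import torch

# two fp32 "weights" that share one flat storage, like PyTorch's flattened RNN weights
fp32_flat = torch.randn(24)
fp32_weights = [fp32_flat[:16].view(4, 4), fp32_flat[16:]]

fp16_flat = torch.empty(24, dtype=torch.half)
fp16_weights = new_synthesize_flattened_rnn_weights(fp32_weights, fp16_flat,
                                                    rnn_fn='lstm', verbose=False)

# each fp16 view aliases fp16_flat at the matching offset and holds a casted copy
assert fp16_weights[0].data_ptr() == fp16_flat.data_ptr()
assert fp16_weights[1].shape == (8,)
```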
@@ -34,7 +34,7 @@ def cached_cast(mod, fn, cast_fn, handle,
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(cast_fn, fn, verbose)
wrapper = make_cast_wrapper(orig_fn, cast_fn, handle, try_caching)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
# `handle` arg is unused, but simplifies API to make `make_cast_wrapper`
def make_promote_wrapper(orig_fn, cast_fn, handle=None):
@@ -54,13 +54,13 @@ def make_promote_wrapper(orig_fn, cast_fn, handle=None):
.format(types))
return wrapper
def promote(mod, fn, verbose=False):
def promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
wrapper = make_promote_wrapper(orig_fn, maybe_float)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
def sequence_promote(mod, fn, verbose=False):
def sequence_promote(mod, fn, handle, verbose=False):
orig_fn = utils.get_func(mod, fn)
maybe_float = utils.verbosify(utils.maybe_float, fn, verbose)
@functools.wraps(orig_fn)
@@ -76,9 +76,9 @@ def sequence_promote(mod, fn, verbose=False):
# TODO: other mixed-type cases aren't due to amp.
# Just pass through?
return orig_fn(seq, *args, **kwargs)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
def promote_match_arg0(mod, fn, verbose=False):
def promote_match_arg0(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
@@ -95,9 +95,9 @@ def promote_match_arg0(mod, fn, verbose=False):
cast_fn = utils.verbosify(cast_fn, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_any_half(mod, fn, custom_err_msg=None):
def err_if_any_half(mod, fn, handle, custom_err_msg=None):
if not utils.has_func(mod, fn):
return
@@ -113,9 +113,9 @@ def err_if_any_half(mod, fn, custom_err_msg=None):
'{} with fp16 arguments.'.format(fn))
else:
return orig_fn(*args, **kwargs)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
def err_if_arg0_half(mod, fn, verbose=False):
def err_if_arg0_half(mod, fn, handle, verbose=False):
if not utils.has_func(mod, fn):
return
@@ -130,7 +130,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
cast_fn = utils.verbosify(utils.maybe_float, fn, verbose)
new_args = utils.casted_args(cast_fn, args, kwargs)
return orig_fn(arg0, *new_args, **kwargs)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
# Current RNN approach:
# - Wrap top-level `RNN` function in thnn backend
@@ -140,7 +140,7 @@ def err_if_arg0_half(mod, fn, verbose=False):
# - We interpose on the factory function to:
# 1) Interpose on the actual forward function and put in casts
# 2) Insert an fp16 `flat_weight` if necessary
def rnn_cast(backend, fn, verbose=False):
def rnn_cast(backend, fn, handle, verbose=False):
orig_rnn = utils.get_func(backend, fn)
@functools.wraps(orig_rnn)
def rnn_wrapper(*args, **kwargs):
@@ -203,7 +203,39 @@ def rnn_cast(backend, fn, verbose=False):
return forward(*new_args, **fkwargs)
return fwd_wrapper
utils.set_func(backend, fn, rnn_wrapper)
utils.set_func_save(handle, backend, fn, rnn_wrapper)
def new_rnn_cast(fn, handle, verbose=False):
mod = torch.nn.modules.rnn._rnn_impls
orig_fn = utils.get_func(mod, fn)
cast_fn = utils.verbosify(utils.maybe_half, fn, verbose)
@functools.wraps(orig_fn)
def wrapper(*args, **kwargs):
# Exact call signature from modules/rnn.py
assert len(args) == 9
assert len(kwargs) == 0
if isinstance(args[6], bool):
params_idx = 2 # Not PackedSequence case
else:
params_idx = 3 # PackedSequence case
new_args = []
for i, arg in enumerate(args):
if i == params_idx:
num_params = sum([x.numel() for x in arg])
fp16_weight_buf = args[0].new_empty((num_params,),
dtype=torch.half)
casted_weights = utils.new_synthesize_flattened_rnn_weights(
arg, fp16_weight_buf, fn, verbose)
new_args.append(casted_weights)
elif utils.is_fp_tensor(arg):
new_args.append(cast_fn(arg))
else:
new_args.append(arg)
return orig_fn(*new_args)
utils.set_func_save(handle, mod, fn, wrapper)
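The magic numbers in the wrapper above (9 positional args, `args[6]`, `params_idx` of 2 or 3) come from how `torch/nn/modules/rnn.py` calls into `_rnn_impls` in the PyTorch release this targets. Roughly, as a hedged reconstruction for orientation (not copied from this diff):

```python
# non-PackedSequence call: args[6] is the boolean `training` flag, weights sit at index 2
#   impl(input, hx, flat_weights, bias, num_layers, dropout, training, bidirectional, batch_first)
# PackedSequence call:     args[6] is the float `dropout`, weights sit at index 3
#   impl(input, batch_sizes, hx, flat_weights, bias, num_layers, dropout, training, bidirectional)
```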
def disable_casts(mod, fn, handle):
if not utils.has_func(mod, fn):
@@ -214,4 +246,4 @@ def disable_casts(mod, fn, handle):
def wrapper(*args, **kwargs):
with handle._disable_casts():
return orig_fn(*args, **kwargs)
utils.set_func(mod, fn, wrapper)
utils.set_func_save(handle, mod, fn, wrapper)
@@ -40,6 +40,7 @@ class FP16_Optimizer(object):
static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.
``init_optimizer`` is expected to have been constructed in the ordinary way.
It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
@@ -105,10 +106,13 @@ class FP16_Optimizer(object):
init_optimizer,
static_loss_scale=1.0,
dynamic_loss_scale=False,
dynamic_loss_args=None):
dynamic_loss_args=None,
verbose=True):
if not torch.cuda.is_available():
raise SystemError("Cannot use fp16 without CUDA.")
self.verbose = verbose
self.optimizer = init_optimizer
# init_state_dict sets up an alternative way to cast per-param state tensors.
# Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
@@ -118,15 +122,15 @@ class FP16_Optimizer(object):
self.fp32_from_fp16_groups = []
self.fp32_from_fp32_groups = []
for i, param_group in enumerate(self.optimizer.param_groups):
print("FP16_Optimizer processing param group {}:".format(i))
self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
fp16_params_this_group = []
fp32_params_this_group = []
fp32_from_fp16_params_this_group = []
for i, param in enumerate(param_group['params']):
if param.requires_grad:
if param.type() == 'torch.cuda.HalfTensor':
print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
.format(param.size()))
self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
.format(param.size()))
fp16_params_this_group.append(param)
master_param = param.detach().clone().float()
master_param.requires_grad = True
@@ -137,8 +141,8 @@ class FP16_Optimizer(object):
if param in self.optimizer.state:
self.optimizer.state[master_param] = self.optimizer.state.pop(param)
elif param.type() == 'torch.cuda.FloatTensor':
print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
.format(param.size()))
self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
.format(param.size()))
fp32_params_this_group.append(param)
param_group['params'][i] = param
else:
@@ -169,6 +173,10 @@ class FP16_Optimizer(object):
self.first_closure_call_this_step = True
self.clip_grad_norm = clip_grad_norm
def maybe_print(self, msg):
if self.verbose:
print(msg)
def __getstate__(self):
raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
@@ -176,20 +184,31 @@ class FP16_Optimizer(object):
def __setstate__(self, state):
raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")
def zero_grad(self):
def zero_grad(self, set_grads_to_None=False):
"""
Zero fp32 and fp16 parameter grads.
"""
# In principle, only the .grad attributes of the model params need to be zeroed,
# because gradients are copied into the FP32 master params. However, we zero
# all gradients owned by the optimizer, just to be safe:
self.optimizer.zero_grad()
for group in self.optimizer.param_groups:
for p in group['params']:
if set_grads_to_None:
p.grad = None
else:
if p.grad is not None:
p.grad.detach_()
p.grad.zero_()
# Zero fp16 gradients owned by the model:
for fp16_group in self.fp16_groups:
for param in fp16_group:
if param.grad is not None:
param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
param.grad.zero_()
if set_grads_to_None:
param.grad = None
else:
if param.grad is not None:
param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
param.grad.zero_()
def _check_overflow(self):
params = []
......
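Putting the FP16_Optimizer changes above together (a hedged usage sketch: the model and training-loop names are illustrative, and `backward`/`step` follow the existing FP16_Optimizer API, which this diff does not change):

```python
import torch
from apex.fp16_utils import FP16_Optimizer

model = torch.nn.Linear(1024, 16).cuda().half()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = FP16_Optimizer(optimizer,
                           dynamic_loss_scale=True,
                           verbose=False)             # new: silence the per-param-group printout

optimizer.zero_grad(set_grads_to_None=True)           # new: optionally free grads instead of zeroing
x = torch.randn(64, 1024, device='cuda', dtype=torch.half)
loss = model(x).float().sum()
optimizer.backward(loss)                              # scales the loss, copies fp16 grads to fp32 master grads
optimizer.step()
```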
@@ -4,8 +4,26 @@ import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable
from collections import OrderedDict
from itertools import chain
import copy
# apply_dist_call requires that tensors in 'bucket' are all the same type.
def apply_flat_dist_call(bucket, call, extra_args=None):
coalesced = _flatten_dense_tensors(bucket)
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
if extra_args is not None:
call(coalesced, *extra_args)
else:
call(coalesced)
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
flat_dist_call.warn_on_half = True
buckets = OrderedDict()
@@ -15,27 +33,11 @@ def flat_dist_call(tensors, call, extra_args=None):
buckets[tp] = []
buckets[tp].append(tensor)
if flat_dist_call.warn_on_half:
if torch.cuda.HalfTensor in buckets:
print("WARNING: gloo dist backend for half parameters may be extremely slow." +
" It is recommended to use the NCCL backend in this case.")
flat_dist_call.warn_on_half = False
for tp in buckets:
bucket = buckets[tp]
coalesced = _flatten_dense_tensors(bucket)
if extra_args is not None:
call(coalesced, *extra_args)
else:
call(coalesced)
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
buf.copy_(synced)
apply_flat_dist_call(bucket, call, extra_args)
def extract_tensors(maybe_tensor, tensor_list):
if torch.is_tensor(maybe_tensor):
tensor_list.append(maybe_tensor)
@@ -117,173 +119,235 @@ class DistributedDataParallel(Module):
Args:
module: Network definition to be run in multi-gpu/distributed mode.
message_size (Default = 1e7): Minimum number of elements in a communication bucket.
shared_param (Default = False): If your model uses shared parameters this must be True. It will disable bucketing of parameters to avoid race conditions.
delay_allreduce (Default = False): Delay all communication to the end of the backward pass. This disables overlapping communication with computation.
"""
def __init__(self, module, message_size=10000000, shared_param=False):
def __init__(self, module, message_size=10000000, delay_allreduce=False, shared_param=None):
super(DistributedDataParallel, self).__init__()
self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
self.shared_param = shared_param
# Backward/forward compatibility around
# https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36
if(hasattr(dist, "get_backend")):
self._backend = dist.get_backend()
self.backend_enum_holder = dist.DistBackend
else:
self._backend = dist._backend
self.backend_enum_holder = dist.dist_backend
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
self.delay_allreduce = delay_allreduce
self.message_size = message_size
#reference to last iterations parameters to see if anything has changed
self.param_refs = []
self.reduction_stream = torch.cuda.Stream()
self.module = module
self.param_list = list(self.module.parameters())
if dist._backend == dist.dist_backend.NCCL:
for param in self.param_list:
if self._backend == self.backend_enum_holder.NCCL:
for param in self.module.parameters():
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.record = []
self.active_params = []
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
"torch.cuda.FloatTensor" : 1,
"torch.cuda.DoubleTensor" : 2}
self.create_hooks()
flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) )
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream()
def __getstate__(self):
attrs = copy.copy(self.__dict__)
if dist._backend != dist.dist_backend.NCCL:
if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream']
return attrs
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
def sync_bucket_structure(self):
# Append leftover buckets
for tmp_bucket in self.tmp_buckets:
if len(tmp_bucket) > 0:
self.buckets.append(tmp_bucket)
self.num_buckets = len(self.buckets)
self.bucket_sizes = [len(bucket) for bucket in self.buckets]
info_tensor = torch.cuda.IntTensor([self.num_buckets] +
self.bucket_sizes +
list(chain(*self.buckets)))
dist.broadcast(info_tensor, 0)
info = [int(entry) for entry in info_tensor]
self.num_buckets = info[0]
self.bucket_sizes = info[1:self.num_buckets + 1]
self.buckets = [[None for _ in range(self.bucket_sizes[i])] for i in range(self.num_buckets)]
flattened_buckets = info[self.num_buckets + 1:]
flat_i = 0
for bucket_idx in range(self.num_buckets):
for bucket_loc in range(self.bucket_sizes[bucket_idx]):
param_i = flattened_buckets[flat_i]
self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
flat_i += 1
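As a concrete illustration of the encoding above (editorial example, not from the diff), suppose rank 0 built ``self.buckets = [[3, 0], [2, 1, 4]]`` over five active parameters; then every rank receives and decodes:

```python
# info_tensor = [num_buckets] + bucket_sizes + flattened bucket contents
#             = [2] + [2, 3] + [3, 0, 2, 1, 4]
#             = [2, 2, 3, 3, 0, 2, 1, 4]
# decoding: num_buckets = 2, bucket_sizes = [2, 3],
#           bucket 0 holds active params 3 and 0, bucket 1 holds active params 2, 1, 4
```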
def create_hooks(self):
#all reduce gradient hook
# Fallback hook that's only called at the end of backward.
# Used if you deliberately want to delay allreduces to the end, or to refresh the
# bucket structure that will be used to overlap communication with computation in later
# iterations.
def allreduce_params():
if not self.needs_reduction:
return
self.needs_reduction = False
# Bucket record refresh
if not self.delay_allreduce:
if self.needs_refresh:
self.sync_bucket_structure()
#parameter ordering refresh
if self.needs_refresh and not self.shared_param:
t_record = torch.cuda.IntTensor(self.record)
dist.broadcast(t_record, 0)
self.record = [int(entry) for entry in t_record]
self.needs_refresh = False
self.needs_refresh = False
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
def flush_buckets():
if not self.needs_reduction:
return
self.needs_reduction = False
grads = []
for i in range(self.ready_end, len(self.param_state)):
param = self.param_refs[self.record[i]]
if param.grad is not None:
grads.append(param.grad.data)
grads = [param.grad.data for param in self.ready_params] + grads
if(len(grads)>0):
orig_stream = torch.cuda.current_stream()
with torch.cuda.stream(self.reduction_stream):
self.reduction_stream.wait_stream(orig_stream)
flat_dist_call(grads, dist.all_reduce)
def overlapping_backward_epilogue():
torch.cuda.current_stream().wait_stream(self.reduction_stream)
# Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets:
raise RuntimeError("In epilogue, next_bucket != num_buckets. "
"This probably indicates some buckets were not allreduced.")
for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
if actual != expected:
raise RuntimeError("Some param buckets were not allreduced.")
self.grad_accs = []
for param in self.module.parameters():
if param.requires_grad:
def wrapper(param):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
# each forward, e.g., backward passes with retain_graph=True?
# needs_refresh and callback_queued are both vulnerable states.
if not self.delay_allreduce and self.needs_refresh:
# Use the backward pass to build the bucket structure on the fly.
active_i = self.param_id_to_active_i[id(param)]
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
self.tmp_numels[current_type] += param.numel()
if self.tmp_numels[current_type] >= self.message_size:
self.buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.comm_ready_buckets(param)
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
wrapper(param)
def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream())
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
if self.buckets[bucket_idx][bucket_loc] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error.")
self.buckets[bucket_idx][bucket_loc] = param.grad.data
self.buckets_ready_size[bucket_idx] += 1
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
self.reduction_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.reduction_stream):
apply_flat_dist_call(self.buckets[bucket_idx], dist.all_reduce)
self.next_bucket += 1
# Reversing upstream's logic here, because we constructed our buckets based on
# the order things were received during backward.
if len(self.ready_buckets_not_reduced) > 0:
sorted_todo = sorted(self.ready_buckets_not_reduced)
for i in sorted_todo:
# Nothing can be reduced now
if i > self.next_bucket:
break
elif i == self.next_bucket:
apply_flat_dist_call(self.buckets[i], dist.all_reduce)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
raise ValueError("i should always be >= next_bucket")
else:
self.ready_buckets_not_reduced.add(bucket_idx)
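To make the ordering logic above concrete (editorial trace, not from the diff): allreduces are always launched in bucket order 0, 1, 2, ... even if gradients finish out of order.

```python
# Suppose three buckets fill in the order 2, 0, 1 while next_bucket starts at 0:
#   bucket 2 fills -> 2 != next_bucket(0), so 2 is stashed in ready_buckets_not_reduced
#   bucket 0 fills -> reduce 0, next_bucket = 1; stashed bucket 2 is still > next_bucket, so it waits
#   bucket 1 fills -> reduce 1, next_bucket = 2; then the stashed bucket 2 is reduced, next_bucket = 3
```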
for param_i, param in enumerate([p for p in self.module.parameters() if p.requires_grad]):
def wrapper(param_i):
def allreduce_hook(*unused):
if self.needs_refresh:
self.record.append(param_i)
Variable._execution_engine.queue_callback(allreduce_params)
else:
Variable._execution_engine.queue_callback(flush_buckets)
self.comm_ready_buckets(self.record.index(param_i))
if param.requires_grad:
param.register_hook(allreduce_hook)
wrapper(param_i)
def comm_ready_buckets(self, param_ind):
if self.param_state[param_ind] != 0:
raise RuntimeError("Error: Your model uses shared parameters, DDP flag shared_params must be set to True in initialization.")
if self.param_state[self.ready_end] == 0:
self.param_state[param_ind] = 1
return
while self.ready_end < len(self.param_state) and self.param_state[self.ready_end] == 1:
self.ready_params.append(self.param_refs[self.record[self.ready_end]])
self.ready_numel += self.ready_params[-1].numel()
self.ready_end += 1
if self.ready_numel < self.message_size:
self.param_state[param_ind] = 1
return
grads = [param.grad.data for param in self.ready_params]
bucket = []
bucket_inds = []
while grads:
bucket.append(grads.pop(0))
cumm_size = 0
for ten in bucket:
cumm_size += ten.numel()
if cumm_size < self.message_size:
continue
evt = torch.cuda.Event()
evt.record(torch.cuda.current_stream())
evt.wait(stream=self.reduction_stream)
with torch.cuda.stream(self.reduction_stream):
flat_dist_call(bucket, dist.all_reduce)
for i in range(self.ready_start, self.ready_start+len(bucket)):
self.param_state[i] = 2
self.ready_params.pop(0)
self.param_state[param_ind] = 1
def forward(self, *inputs, **kwargs):
param_list = [param for param in self.module.parameters() if param.requires_grad]
#Force needs_refresh to True if there are shared params
#this will force it to always, only call flush_buckets which is safe
#for shared parameters in the model.
#Parentheses are not necessary for correct order of operations, but make the intent clearer.
if (not self.param_refs) or self.shared_param:
self.needs_refresh = True
else:
self.needs_refresh = (
(len(param_list) != len(self.param_refs)) or any(
[param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)]))
if self.needs_refresh:
self.record = []
result = self.module(*inputs, **kwargs)
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
# Conditions under which to refresh self.record
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
if self.needs_refresh:
self.buckets = []
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
self.param_state = [0 for i in range(len(param_list))]
self.param_refs = param_list
self.needs_reduction = True
self.ready_start = 0
self.ready_end = 0
self.ready_params = []
self.ready_numel = 0
self.active_params = param_list
self.callback_queued = False
return self.module(*inputs, **kwargs)
return result
@@ -26,7 +26,7 @@ python main.py -a alexnet --lr 0.01 /path/to/imagenet/folder
```
The directory at /path/to/imagenet/directory should contain two subdirectories called "train"
and "val" that contain the training and validation data respectively. Train images are expected to be 256x256 jpegs.
and "val" that contain the training and validation data respectively.
## Distributed training
@@ -46,19 +46,17 @@ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py args...
```bash
### Softlink training dataset into current directory
$ ln -sf /data/imagenet/train-jpeg-256x256/ train
$ ln -sf /data/imagenet/train-jpeg/ train
### Softlink validation dataset into current directory
$ ln -sf /data/imagenet/val-jpeg/ val
### Single-process training
$ python main.py -a resnet50 --fp16 --b 256 --workers 4 ./
### Multi-process training (uses all visible GPU on the node)
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py -a resnet50 --fp16 --b 256 --workers 4 ./
$ python main.py -a resnet50 --fp16 --b 256 --workers 4 --static-loss-scale 128.0 ./
### Multi-process training (uses all visible GPUs on the node)
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main.py -a resnet50 --fp16 --b 256 --workers 4 --static-loss-scale 128.0 ./
### Multi-process training on GPUs 0 and 1 only
$ export CUDA_VISIBLE_DEVICES=0,1
$ python -m torch.distributed.launch --nproc_per_node=2 main.py -a resnet50 --fp16 --b 256 --workers 4 ./
### Multi-process training with FP16_Optimizer, default loss scale 1.0 (still uses FP32 master params)
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_fp16_optimizer.py -a resnet50 --fp16 --b 256 --workers 4 ./
# Multi-process training with FP16_Optimizer, static loss scale
### Multi-process training with FP16_Optimizer, static loss scale 128.0 (still uses FP32 master params)
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_fp16_optimizer.py -a resnet50 --fp16 --b 256 --static-loss-scale 128.0 --workers 4 ./
### Multi-process training with FP16_Optimizer, dynamic loss scaling
$ python -m torch.distributed.launch --nproc_per_node=NUM_GPUS main_fp16_optimizer.py -a resnet50 --fp16 --b 256 --dynamic-loss-scale --workers 4 ./
......
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -89,6 +89,7 @@ def fast_collate(batch):
best_prec1 = 0
args = parser.parse_args()
def main():
global best_prec1, args
@@ -121,8 +122,11 @@ def main():
if args.fp16:
model = network_to_half(model)
if args.distributed:
# shared param turns off bucketing in DDP, for lower latency runs this can improve perf
model = DDP(model)
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# model = DDP(model)
# delay_allreduce delays all communication to the end of the backward pass.
model = DDP(model, delay_allreduce=True)
global model_params, master_params
if args.fp16:
@@ -133,25 +137,32 @@ def main():
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# Scale learning rate based on per-process batch size
args.lr = args.lr*float(args.batch_size)/256.
# Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
optimizer = torch.optim.SGD(master_params, args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
# optionally resume from a checkpoint
# Optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
# Use a local scope to avoid dangling references
def resume():
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
if args.fp16:
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
resume()
# Data loading code
traindir = os.path.join(args.data, 'train')
@@ -213,13 +224,19 @@ def main():
if args.local_rank == 0:
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best)
# Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}
if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher():
def __init__(self, loader):
......
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -128,14 +128,17 @@ def main():
if args.fp16:
model = network_to_half(model)
if args.distributed:
# shared param turns off bucketing in DDP, for lower latency runs this can improve perf
model = DDP(model, shared_param=True)
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# model = DDP(model)
# delay_allreduce delays all communication to the end of the backward pass.
model = DDP(model, delay_allreduce=True)
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# Scale learning rate based on per-process batch size
args.lr = args.lr*float(args.batch_size)/256.
# Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
@@ -144,19 +147,23 @@ def main():
static_loss_scale=args.static_loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale)
# optionally resume from a checkpoint
# Optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
# Use a local scope to avoid dangling references
def resume():
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
# An FP16_Optimizer instance's state dict internally stashes the master params.
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
resume()
# Data loading code
traindir = os.path.join(args.data, 'train')
......
@@ -43,9 +43,9 @@ parser.add_argument('--epochs', default=90, type=int, metavar='N',
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='initial learning rate')
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
@@ -133,25 +133,32 @@ def main():
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda()
# Scale learning rate based on per-process batch size
args.lr = args.lr*float(args.batch_size)/256.
# Scale learning rate based on global batch size
args.lr = args.lr*float(args.batch_size*args.world_size)/256.
optimizer = torch.optim.SGD(master_params, args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay)
# optionally resume from a checkpoint
# Optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
# Use a local scope to avoid dangling references
def resume():
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
if args.fp16:
saved_master_params = checkpoint['master_params']
for master, saved in zip(master_params, saved_master_params):
master.data.copy_(saved.data)
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
resume()
# Data loading code
traindir = os.path.join(args.data, 'train')
@@ -213,13 +220,19 @@ def main():
if args.local_rank == 0:
is_best = prec1 > best_prec1
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}, is_best)
# Use local scope to avoid dangling references
def create_and_save_checkpoint():
checkpoint_dict = {
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
'optimizer' : optimizer.state_dict(),
}
if args.fp16:
checkpoint_dict['master_params'] = master_params
save_checkpoint(checkpoint_dict, is_best)
create_and_save_checkpoint()
class data_prefetcher():
def __init__(self, loader):
@@ -307,7 +320,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
if args.fp16:
model.zero_grad()
loss.backward()
reducer.reduce()
if args.distributed:
reducer.reduce()
model_grads_to_master_grads(model_params, master_params)
if args.static_loss_scale != 1:
for param in master_params:
@@ -317,7 +331,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
else:
optimizer.zero_grad()
loss.backward()
reducer.reduce()
if args.distributed:
reducer.reduce()
optimizer.step()
torch.cuda.synchronize()
......
@@ -4,32 +4,26 @@ from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os
parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
help='distributed backend')
parser.add_argument('--world-size', default=1, type=int,
help='Number of GPUs to use. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
parser.add_argument('--rank', default=0, type=int,
help='Used for multi-process training. Can either be manually set ' +
'or automatically set by using \'python -m multiproc\'.')
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
args.distributed = args.world_size > 1
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed:
torch.cuda.set_device(args.rank % torch.cuda.device_count())
dist.init_process_group(args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank)
args.gpu = args.local_rank % torch.cuda.device_count()
torch.cuda.set_device(args.gpu)
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
args.world_size = torch.distributed.get_world_size()
torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)
class Model(Module):
def __init__(self):
@@ -40,24 +34,31 @@ class Model(Module):
return (input*self.a)*self.b
model = DDP(Model(), message_size=1)
# model = DDP(Model(), delay_allreduce=True)
x = torch.cuda.FloatTensor(4096*4096)
passed = True
for i in range(10):
x.fill_(i + args.rank) # fill x with new values every iteration for sanity
x.fill_(i + args.local_rank) # fill x with new values every iteration for sanity
model.zero_grad()
out = model(x)
loss = out.sum()
torch.cuda.nvtx.range_push("backward")
# torch.cuda.nvtx.range_push("backward")
loss.backward()
torch.cuda.nvtx.range_pop()
# torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("synchronize() + info")
torch.cuda.synchronize()
# torch.cuda.nvtx.range_push("synchronize() + info")
# torch.cuda.synchronize()
print("i = {}".format(i))
def info(name, param, val):
expected = val*4096*4096*(2.*i+1)/2.
actual = param.grad.data.sum().item()
print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
param.grad.data_ptr(), val*4096*4096*(2.*i+1)/2., param.grad.data.sum().item()))
info("model.a", model.module.a, 2.)
info("model.b", model.module.b, 1.)
torch.cuda.nvtx.range_pop()
param.grad.data_ptr(), expected, actual))
return (expected == actual)
if not info("model.a", model.module.a, 2.): passed = False
if not info("model.b", model.module.b, 1.): passed = False
# torch.cuda.nvtx.range_pop()
print("passed = ", passed)
#!/bin/bash
CUDA_VISIBLE_DEVICES=0,1 python -m apex.parallel.multiproc ddp_race_condition.py
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py
import unittest
import functools as ft
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
for fn, typ in it.product(fns, expected.keys()):
x = torch.randn(input_shape, dtype=typ).requires_grad_()
y = fn(x)
test_case.assertEqual(y.type(), expected[typ])
if test_backward:
y.float().sum().backward()
test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
class TestBasicCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_linear_is_half(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
def test_conv2d_is_half(self):
m = nn.Conv2d(self.c, self.c, self.k)
f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
def test_softmax_is_float(self):
m = nn.Softmax(dim=1)
f = ft.partial(F.softmax, dim=1)
run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
def test_group_norm_is_float(self):
m = nn.GroupNorm(num_groups=4, num_channels=self.c)
run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
def test_mse_loss_is_float(self):
shape = (self.b, self.h)
target = torch.randn(shape)
mod = nn.MSELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.mse_loss, target=target)
run_layer_test(self, [m], ALWAYS_FLOAT, shape)
def test_relu_is_match(self):
run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
def test_batch_norm_is_match(self):
m = nn.BatchNorm2d(num_features=self.c)
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=True)
run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
# Test forward-only for BN inference
m.eval()
f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
weight=m.weight, bias=m.bias, training=False)
run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
test_backward=False)
class TestBannedMethods(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def bce_common(self, assertion):
shape = (self.b, self.h)
target = torch.rand(shape)
mod = nn.BCELoss()
m = lambda x: mod(x, target)
f = ft.partial(F.binary_cross_entropy, target=target)
for fn in [m, f]:
x = torch.rand(shape, dtype=torch.half)
assertion(fn, x)
def test_bce_raises_by_default(self):
assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
self.bce_common(assertion)
def test_bce_is_float_with_allow_banned(self):
self.handle._deactivate()
self.handle = amp.init(enabled=True, allow_banned=True)
assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
self.bce_common(assertion)
class TestTensorCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def test_matmul_method_is_half(self):
other = torch.randn(self.h, self.h)
lhs = lambda x: x.matmul(other)
rhs = lambda x: other.matmul(x)
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
def test_matmul_op_is_half(self):
other = torch.randn(self.h, self.h)
lhs = lambda x: x @ other
rhs = lambda x: other @ x
run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
def test_pow_method_is_float(self):
fn = lambda x: x.pow(2.)
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
def test_pow_op_is_float(self):
fn = lambda x: x ** 2.
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
def test_cpu_is_float(self):
fn = lambda x: x.cpu()
always_cpu_float = {torch.float: 'torch.FloatTensor',
torch.half: 'torch.FloatTensor'}
run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
def test_sum_is_float(self):
fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
unittest.main()
import unittest
import itertools as it
from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from utils import common_init, HALF, FLOAT, DTYPES
class TestPromotion(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
type_pairs = it.product(DTYPES, DTYPES)
for fn, (xtype, ytype) in it.product(fns, type_pairs):
x = torch.randn(input_shape, dtype=xtype).requires_grad_()
x_leaf = x
if x_inplace:
# We need a non-leaf to call in place on
x = x.clone()
y = torch.randn(input_shape, dtype=ytype)
out = fn(x, y)
if x_inplace:
# In place: always match xtype
self.assertEqual(out.type(), x.type())
else:
# Out of place: match widest type
if xtype == torch.float or ytype == torch.float:
self.assertEqual(out.type(), FLOAT)
else:
self.assertEqual(out.type(), HALF)
out.float().sum().backward()
self.assertEqual(x_leaf.grad.dtype, xtype)
def test_atan2_matches_widest(self):
fns = [lambda x, y : torch.atan2(x, y),
lambda x, y : x.atan2(y)]
self.run_binary_promote_test(fns, (self.b,))
def test_mul_matches_widest(self):
fns = [lambda x, y : torch.mul(x, y),
lambda x, y: x.mul(y)]
self.run_binary_promote_test(fns, (self.b,))
def test_cat_matches_widest(self):
shape = self.b
ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
x_float = torch.randn(shape)
out = torch.cat(ys + [x_float])
self.assertEqual(out.type(), FLOAT)
x_half = torch.randn(shape, dtype=torch.half)
out = torch.cat(ys + [x_half])
self.assertEqual(out.type(), HALF)
def test_inplace_exp_is_error_for_half(self):
xs = torch.randn(self.b)
xs.exp_()
self.assertEqual(xs.type(), FLOAT)
xs = torch.randn(self.b, dtype=torch.half)
with self.assertRaises(NotImplementedError):
xs.exp_()
def test_inplace_add_matches_self(self):
fn = lambda x, y: x.add_(y)
self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
if __name__ == '__main__':
unittest.main()
import unittest
from apex import amp
import random
import torch
from torch import nn
from utils import common_init, HALF
class TestRnnCells(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_cell_test(self, cell, state_tuple=False):
shape = (self.b, self.h)
for typ in [torch.float, torch.half]:
xs = [torch.randn(shape, dtype=typ).requires_grad_()
for _ in range(self.t)]
hidden_fn = lambda: torch.zeros(shape, dtype=typ)
if state_tuple:
hidden = (hidden_fn(), hidden_fn())
else:
hidden = hidden_fn()
outputs = []
for i in range(self.t):
hidden = cell(xs[i], hidden)
if state_tuple:
output = hidden[0]
else:
output = hidden
outputs.append(output)
for y in outputs:
self.assertEqual(y.type(), HALF)
outputs[-1].float().sum().backward()
for i, x in enumerate(xs):
self.assertEqual(x.grad.dtype, x.dtype)
def test_rnn_cell_is_half(self):
cell = nn.RNNCell(self.h, self.h)
self.run_cell_test(cell)
def test_gru_cell_is_half(self):
cell = nn.GRUCell(self.h, self.h)
self.run_cell_test(cell)
def test_lstm_cell_is_half(self):
cell = nn.LSTMCell(self.h, self.h)
self.run_cell_test(cell, state_tuple=True)
class TestRnns(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=True)
common_init(self)
def tearDown(self):
self.handle._deactivate()
def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
for typ in [torch.float, torch.half]:
x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
self.b, self.h), dtype=typ)
if state_tuple:
hidden = (hidden_fn(), hidden_fn())
else:
hidden = hidden_fn()
output, _ = rnn(x, hidden)
self.assertEqual(output.type(), HALF)
output[-1, :, :].float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
def test_rnn_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=layers,
nonlinearity='relu', bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
def test_gru_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.GRU(input_size=self.h, hidden_size=self.h, num_layers=layers,
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir)
def test_lstm_is_half(self):
configs = [(1, False), (2, False), (2, True)]
for layers, bidir in configs:
rnn = nn.LSTM(input_size=self.h, hidden_size=self.h, num_layers=layers,
bidirectional=bidir)
self.run_rnn_test(rnn, layers, bidir, state_tuple=True)
def test_rnn_packed_sequence(self):
num_layers = 2
rnn = nn.RNN(input_size=self.h, hidden_size=self.h, num_layers=num_layers)
for typ in [torch.float, torch.half]:
x = torch.randn((self.t, self.b, self.h), dtype=typ).requires_grad_()
lens = sorted([random.randint(self.t // 2, self.t) for _ in range(self.b)],
reverse=True)
# `pack_padded_sequence` breaks if default tensor type is non-CPU
torch.set_default_tensor_type(torch.FloatTensor)
lens = torch.tensor(lens, dtype=torch.int64, device=torch.device('cpu'))
packed_seq = nn.utils.rnn.pack_padded_sequence(x, lens)
torch.set_default_tensor_type(torch.cuda.FloatTensor)
hidden = torch.zeros((num_layers, self.b, self.h), dtype=typ)
output, _ = rnn(packed_seq, hidden)
self.assertEqual(output.data.type(), HALF)
output.data.float().sum().backward()
self.assertEqual(x.grad.dtype, x.dtype)
if __name__ == '__main__':
unittest.main()
import torch
HALF = 'torch.cuda.HalfTensor'
FLOAT = 'torch.cuda.FloatTensor'
DTYPES = [torch.half, torch.float]
ALWAYS_HALF = {torch.float: HALF,
torch.half: HALF}
ALWAYS_FLOAT = {torch.float: FLOAT,
torch.half: FLOAT}
MATCH_INPUT = {torch.float: FLOAT,
torch.half: HALF}
def common_init(test_case):
test_case.h = 64
test_case.b = 16
test_case.c = 16
test_case.k = 3
test_case.t = 10
torch.set_default_tensor_type(torch.cuda.FloatTensor)
import unittest
import functools as ft
import itertools as it
import torch
from apex.fp16_utils import FP16_Optimizer
class TestFP16Optimizer(unittest.TestCase):
def setUp(self):
N, D_in, D_out = 64, 1024, 16
self.N = N
self.D_in = D_in
self.D_out = D_out
self.x = torch.randn((N, D_in), dtype=torch.float16, device='cuda')
self.y = torch.randn((N, D_out), dtype=torch.float16, device='cuda')
self.model = torch.nn.Linear(D_in, D_out).cuda().half()
# def tearDown(self):
# pass
def test_minimal(self):
pass
def test_minimal_static(self):
pass
def test_minimal_dynamic(self):
pass
def test_closure(self):
pass
def test_closure_dynamic(self):
pass
def test_save_load(self):
pass
if __name__ == '__main__':
unittest.main()