Commit 4d6ed501 authored by Deyu Fu's avatar Deyu Fu
Browse files

Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2

parents 690b1f71 9f64bf27
import torch
import xentropy_cuda
class SoftmaxCrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
losses, max_log_sum_exp = xentropy_cuda.forward(
logits, labels, smoothing, half_to_float)
losses.masked_fill_(labels==padding_idx, 0)
ctx.save_for_backward(logits, max_log_sum_exp, labels,
torch.FloatTensor([smoothing]),
torch.LongTensor([padding_idx]))
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==padding_idx.item(), 0)
grad_logits = xentropy_cuda.backward(
grad_loss.contiguous(), logits, max_log_sum_exp,
labels, smoothing.item())
return grad_logits, None, None, None, None
from .sgd import FusedSGD
from .fused_sgd import FusedSGD
from .novograd import FusedNovoGrad
from .fused_adam_v1 import FusedAdam_v1
from .adam import FusedAdam
#from .sgd import FusedSGD
from .fp16_optimizer import FP16_Optimizer
......@@ -2,8 +2,9 @@ import types
import torch
import importlib
class FusedAdam_v1(torch.optim.Optimizer):
from ..multi_tensor_apply import multi_tensor_applier
class FusedAdam_v1(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``.
......@@ -25,6 +26,8 @@ class FusedAdam_v1(torch.optim.Optimizer):
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
......@@ -35,10 +38,21 @@ class FusedAdam_v1(torch.optim.Optimizer):
def __init__(self, params,
lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False):
weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0):
global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._use_multi_tensor = False
if use_mt:
if not multi_tensor_applier.available:
print("Warning: multi_tensor_applier is unavailable")
else:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0])
self._amp_scale_adjustment = amp_scale_adjustment
if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
......@@ -66,6 +80,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
if closure is not None:
loss = closure()
if hasattr(self, "_amp_stash"):
grads = self._amp_stash.grads
output_params = self._amp_stash.output_params
scale = self._amp_stash.scale*self._amp_scale_adjustment
grad_norms = self._amp_stash.grad_norms
if grads is None:
grads_group = [None]*len(self.param_groups)
# backward compatibility
......@@ -105,6 +125,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
bias_correction = 1 if group['bias_correction'] else 0
if self._use_multi_tensor:
if output_params:
tensorlists = [[],[],[],[],[]]
else:
tensorlists = [[],[],[],[]]
for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None:
......@@ -130,18 +156,43 @@ class FusedAdam_v1(torch.optim.Optimizer):
state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
fused_adam_cuda.adam(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
if self._use_multi_tensor:
pl = [p.data, exp_avg, exp_avg_sq, grad]
if output_param is not None:
pl.append(out_p)
for tl, t in zip(tensorlists, pl):
tl.append(t)
else:
fused_adam_cuda.adam(p.data,
out_p,
exp_avg,
exp_avg_sq,
grad,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
if self._use_multi_tensor:
multi_tensor_applier(
fused_adam_cuda.adam_mt,
self._overflow_buf,
tensorlists,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
explicit_master_params = (hasattr(self, "_amp_stash") and
hasattr(self._amp_stash, "fp32_from_fp16_groups"))
for gid, group in enumerate(self.param_groups):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs = [True, True]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
stash = self._amp_stash
fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if self.materialize_master_grads:
fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
else:
fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for i, p in enumerate(
stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp16_model_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
group['lr'],
nesterov,
first_run,
self.wd_after_momentum,
1.0/self.most_recent_scale)
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
return loss
......@@ -44,7 +44,7 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
if call is dist.all_reduce:
coalesced /= dist.get_world_size()
for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
buf.copy_(synced)
......@@ -54,7 +54,7 @@ def split_half_float_double(tensors):
for i, dtype in enumerate(dtypes):
bucket = [t for t in tensors if t.type() == dtype]
if bucket:
buckets.append(bucket)
buckets.append(bucket)
return buckets
def split_by_type(tensors):
......@@ -69,12 +69,12 @@ def split_by_type(tensors):
# flat_dist_call organizes 'tensors' by type.
def flat_dist_call(tensors, call, extra_args=None):
buckets = split_by_type(tensors)
for tp in buckets:
bucket = buckets[tp]
apply_flat_dist_call(bucket, call, extra_args)
def extract_tensors(maybe_tensor, tensor_list):
if torch.is_tensor(maybe_tensor):
tensor_list.append(maybe_tensor)
......@@ -85,7 +85,7 @@ def extract_tensors(maybe_tensor, tensor_list):
except TypeError:
return
class Reducer(object):
"""
:class:`apex.parallel.Reducer` is a simple class that helps allreduce a module's parameters
......@@ -93,13 +93,13 @@ class Reducer(object):
Unlike :class:`DistributedDataParallel`, :class:`Reducer` will not automatically allreduce
parameters during ``backward()``.
Instead, :class:`Reducer` waits for the user to call ``<reducer_instance>.reduce()`` manually.
This enables, for example, delaying the allreduce to be carried out every
This enables, for example, delaying the allreduce to be carried out every
several iterations instead of every single iteration.
Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
Like :class:`DistributedDataParallel`, :class:`Reducer` averages any tensors it allreduces
over the number of participating processes.
:class:`Reducer` is designed to work with the upstream launch utility script
:class:`Reducer` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`Reducer` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
......@@ -107,7 +107,7 @@ class Reducer(object):
Args:
module_or_grads_list: Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they're all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module's parameters at the beginning of training.
"""
def __init__(self, module_or_grads_list):
if isinstance(module_or_grads_list, Module):
self.module = module_or_grads_list
......@@ -117,26 +117,26 @@ class Reducer(object):
self.module = None
self.grads = []
extract_tensors(module_or_grads_list, self.grads)
def reduce(self):
if self.module:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
flat_dist_call(grads, dist.all_reduce)
else:
flat_dist_call(self.grads, dist.all_reduce)
class DistributedDataParallel(Module):
"""
:class:`apex.parallel.DistributedDataParallel` is a module wrapper that enables
easy multiprocess distributed data parallel training, similar to ``torch.nn.parallel.DistributedDataParallel``. Parameters are broadcast across participating processes on initialization, and gradients are
allreduced and averaged over processes during ``backward()``.
:class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
:class:`DistributedDataParallel` is optimized for use with NCCL. It achieves high performance by
overlapping communication with computation during ``backward()`` and bucketing smaller gradient
transfers to reduce the total number of transfers required.
:class:`DistributedDataParallel` is designed to work with the upstream launch utility script
:class:`DistributedDataParallel` is designed to work with the upstream launch utility script
``torch.distributed.launch`` with ``--nproc_per_node <= number of gpus per node``.
When used with this launcher, :class:`DistributedDataParallel` assumes 1:1 mapping of processes to GPUs.
It also assumes that your script calls ``torch.cuda.set_device(args.rank)`` before creating the model.
......@@ -159,19 +159,23 @@ class DistributedDataParallel(Module):
"""
def __init__(self,
module,
message_size=10000000,
delay_allreduce=False,
def __init__(self,
module,
message_size=10000000,
delay_allreduce=False,
shared_param=None,
allreduce_trigger_params=None,
retain_allreduce_buffers=False,
allreduce_always_fp32=False,
num_allreduce_streams=1,
allreduce_communicators=None,
gradient_average=True,
gradient_predivide_factor=1.0):
gradient_predivide_factor=1.0,
gradient_average_split_factor=None,
prof=False):
super(DistributedDataParallel, self).__init__()
# Backward/forward compatibility around
# Backward/forward compatibility around
# https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and
# https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86
if hasattr(dist, "get_backend"):
......@@ -181,13 +185,26 @@ class DistributedDataParallel(Module):
else:
self.backend_enum_holder = dist.Backend
else:
self._backend = dist._backend
self._backend = dist._backend
self.backend_enum_holder = dist.dist_backend
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.prof = prof
self.allreduce_different_streams = (num_allreduce_streams > 1)
self.num_allreduce_streams = num_allreduce_streams
self.allreduce_communicators = allreduce_communicators
if self.allreduce_communicators:
assert len(allreduce_communicators[0]) == num_allreduce_streams
assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
assert self.allreduce_different_streams
if self.allreduce_different_streams and delay_allreduce:
raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
self.world_size = float(dist.get_world_size())
......@@ -199,27 +216,29 @@ class DistributedDataParallel(Module):
self.custom_allreduce_triggers = False
if allreduce_trigger_params is not None:
if delay_allreduce:
raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
raise ValueError("Setting allreduce_trigger_params is only valid if delay_allreduce=False.")
self.custom_allreduce_triggers = True
self.allreduce_trigger_params = set([id(param) for param in allreduce_trigger_params])
self.delay_allreduce = delay_allreduce
self.message_size = message_size
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.main_stream = torch.cuda.current_stream()
self.bucket_streams = []
self.bucket_events = []
self.module = module
self._disable_allreduce = False
if self._backend == self.backend_enum_holder.NCCL:
for param in self.module.parameters():
assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
self.active_params = []
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
"torch.cuda.FloatTensor" : 1,
"torch.cuda.DoubleTensor" : 2}
......@@ -236,15 +255,21 @@ class DistributedDataParallel(Module):
def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
if self.allreduce_different_streams and delay_allreduce:
raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if self.delay_allreduce:
self.needs_refresh = True
self.bucket_streams = []
self.bucket_events = []
def __getstate__(self):
attrs = copy.copy(self.__dict__)
if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream']
del attrs['self.reduction_event']
del attrs['self.bucket_streams']
del attrs['self.bucket_events']
return attrs
def enable_allreduce(self):
......@@ -252,9 +277,9 @@ class DistributedDataParallel(Module):
def disable_allreduce(self):
self._disable_allreduce = True
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
# Broadcast rank 0's bucket structure across all processes, and have all processes
# regenerate their bucket structures to match.
def sync_bucket_structure(self):
# Append leftover buckets
for tmp_bucket in self.tmp_buckets:
......@@ -264,8 +289,8 @@ class DistributedDataParallel(Module):
self.num_buckets = len(self.active_i_buckets)
self.bucket_sizes = [len(bucket) for bucket in self.active_i_buckets]
info_tensor = torch.cuda.IntTensor([self.num_buckets] +
self.bucket_sizes +
info_tensor = torch.cuda.IntTensor([self.num_buckets] +
self.bucket_sizes +
list(chain(*self.active_i_buckets)))
dist.broadcast(info_tensor, 0)
......@@ -273,27 +298,27 @@ class DistributedDataParallel(Module):
info = [int(entry) for entry in info_tensor]
self.num_buckets = info[0]
self.bucket_sizes = info[1:self.num_buckets + 1]
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
self.bucket_sizes = info[1:self.num_buckets + 1]
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
# Technically, active_i_buckets' work is done. But the information is still useful to
# keep around. Therefore, refresh active_i_buckets based on rank 0 as well.
self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
self.active_i_buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
flattened_buckets = info[self.num_buckets + 1:]
flat_i = 0
for bucket_idx in range(self.num_buckets):
for bucket_idx in range(self.num_buckets):
for bucket_loc in range(self.bucket_sizes[bucket_idx]):
param_i = flattened_buckets[flat_i]
self.active_i_buckets[bucket_idx][bucket_loc] = param_i
self.active_i_buckets[bucket_idx][bucket_loc] = param_i
self.param_id_to_bucket[id(self.active_params[param_i])] = (bucket_idx, bucket_loc)
flat_i += 1
flat_i += 1
def create_hooks(self):
# Fallback hook that's only called at the end of backward.
# Used if you deliberately want to delay allreduces to the end, or to refresh the
# Used if you deliberately want to delay allreduces to the end, or to refresh the
# bucket structure that will be used to overlap communication with computation in later
# iterations.
def allreduce_params():
......@@ -308,9 +333,10 @@ class DistributedDataParallel(Module):
def overlapping_backward_epilogue():
self.reduction_stream.record_event(self.reduction_event)
torch.cuda.current_stream().wait_event(self.reduction_event)
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
# Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets:
raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
......@@ -320,7 +346,7 @@ class DistributedDataParallel(Module):
for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
if actual != expected:
raise RuntimeError("Some param buckets were not allreduced.")
self.grad_accs = []
for param in self.module.parameters():
......@@ -330,6 +356,9 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused):
if self.prof:
torch.cuda.nvtx.range_push("allreduce_hook")
if not self._disable_allreduce:
if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between
......@@ -341,8 +370,8 @@ class DistributedDataParallel(Module):
# Float, half, and double tensors are grouped into buckets separately.
current_type = self.param_type_to_tmp_i[param.type()]
self.tmp_buckets[current_type].append(active_i)
self.tmp_buckets[current_type].append(active_i)
ship_tmp_bucket = False
if self.custom_allreduce_triggers:
......@@ -359,82 +388,133 @@ class DistributedDataParallel(Module):
self.active_i_buckets.append(self.tmp_buckets[current_type])
self.tmp_buckets[current_type] = []
self.tmp_numels[current_type] = 0
if not self.callback_queued:
Variable._execution_engine.queue_callback(allreduce_params)
self.callback_queued = True
else:
if not self.callback_queued:
Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
self.callback_queued = True
self.callback_queued = True
self.comm_ready_buckets(param)
if self.prof:
torch.cuda.nvtx.range_pop()
grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc)
wrapper(param)
def allreduce_bucket(self, bucket):
def _stream_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_streams[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_streams[0]
def _event_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_events[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_events[0]
def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
tensor = flatten(bucket)
tensor_to_allreduce = tensor
if force_default_stream:
bucket_stream = self.main_stream
else:
bucket_stream = self._stream_this_bucket(bucket_idx)
bucket_event = self._event_this_bucket(bucket_idx)
torch.cuda.current_stream().record_event(bucket_event)
bucket_stream.wait_event(bucket_event)
with torch.cuda.stream(bucket_stream):
# self.main_stream.wait_stream(torch.cuda.current_stream())
# torch.cuda.synchronize()
tensor_to_allreduce = tensor
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.allreduce_always_fp32:
tensor_to_allreduce = tensor.float()
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
dist.all_reduce(tensor_to_allreduce)
if self.allreduce_different_streams and not force_default_stream:
dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])
else:
dist.all_reduce(tensor_to_allreduce)
if self.gradient_average:
if self.gradient_predivide_factor != self.world_size:
if self.gradient_average:
tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce)
if not self.retain_allreduce_buffers:
if multi_tensor_applier.available:
multi_tensor_applier(
self.multi_tensor_scale,
self._overflow_buf,
[unflatten(tensor, bucket), bucket],
1.0)
else:
for buf, synced in zip(bucket, unflatten(tensor, bucket)):
buf.copy_(synced)
# I think we actually do need this here. After allreduce_bucket returns, tensor will
# eventually go out of scope and die, at which point it could otherwise be freed for
# further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
tensor.record_stream(bucket_stream)
return tensor
def allreduce_maybe_retain(self, bucket, bucket_idx=-1):
allreduced = self.allreduce_bucket(bucket)
def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
if self.retain_allreduce_buffers:
if self.allreduce_buffers[bucket_idx] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled "
"allreduce buffer. This is almost certainly an error.")
self.allreduce_buffers[bucket_idx] = allreduced
else:
if multi_tensor_applier.available:
multi_tensor_applier(
self.multi_tensor_scale,
self._overflow_buf,
[unflatten(allreduced, bucket), bucket],
1.0)
else:
for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
buf.copy_(synced)
for view, grad in zip(unflatten(allreduced, bucket), bucket):
grad.data = view
# for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
# buf.copy_(synced)
def allreduce_fallback(self):
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
if self.retain_allreduce_buffers:
grads = [param.grad for param in self.module.parameters() if param.grad is not None]
else:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
split_buckets = split_half_float_double(grads)
# If retain_allreduce_buffers is True and delay_allreduce is False,
# this will only be done during the first backward pass, ignored by the
# training script, and overwritten in the next forward pass. So it's harmless.
# this will only be done during the first backward pass, ignored by the
# training script, and overwritten in the next forward pass. So it's harmless.
if self.retain_allreduce_buffers:
self.allreduce_buffers = [None for _ in range(len(split_buckets))]
for i, bucket in enumerate(split_buckets):
allreduced = self.allreduce_maybe_retain(bucket, i)
allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream())
if self.prof:
torch.cuda.nvtx.range_push("comm_ready_buckets")
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
......@@ -442,39 +522,46 @@ class DistributedDataParallel(Module):
raise RuntimeError("The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error.")
self.buckets[bucket_idx][bucket_loc] = param.grad.data
if self.retain_allreduce_buffers:
self.buckets[bucket_idx][bucket_loc] = param.grad
else:
self.buckets[bucket_idx][bucket_loc] = param.grad.data
self.buckets_ready_size[bucket_idx] += 1
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket:
torch.cuda.current_stream().record_event(self.reduction_event)
self.reduction_stream.wait_event(self.reduction_event)
with torch.cuda.stream(self.reduction_stream):
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
# Reversing upstream's logic here, because we constructed our buckets based on
# the order things were received during backward.
if len(self.ready_buckets_not_reduced) > 0:
sorted_todo = sorted(self.ready_buckets_not_reduced)
for i in sorted_todo:
# Nothing can be reduced now
if i > self.next_bucket:
break
elif i == self.next_bucket:
self.allreduce_maybe_retain(self.buckets[i], i)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
raise ValueError("i should always be >= next_bucket")
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1
# Reversing upstream's logic here, because we constructed our buckets based on
# the order things were received during backward.
if len(self.ready_buckets_not_reduced) > 0:
sorted_todo = sorted(self.ready_buckets_not_reduced)
for i in sorted_todo:
# Nothing can be reduced now
if i > self.next_bucket:
break
elif i == self.next_bucket:
self.allreduce_maybe_retain(self.buckets[i], i)
self.ready_buckets_not_reduced.remove(i)
self.next_bucket += 1
else:
raise ValueError("i should always be >= next_bucket")
else:
self.ready_buckets_not_reduced.add(bucket_idx)
if self.prof:
torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self._disable_allreduce:
if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad]
......@@ -483,7 +570,7 @@ class DistributedDataParallel(Module):
# Forward has the authority to set needs_refresh to True, but only allreduce_params
# in backward has the authority to set needs_refresh to False.
# Parentheses are not necessary for correct order of operations, but make the intent clearer.
if ((not self.active_params) or
if ((not self.active_params) or
(len(param_list) != len(self.active_params)) or
any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
self.needs_refresh = True
......@@ -494,19 +581,59 @@ class DistributedDataParallel(Module):
self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
self.tmp_numels = [0, 0, 0]
self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
# self.buckets = [[None for _ in range(self.bucket_sizes[i])]
# for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
b, len(buckets[b]), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_communicators:
self.bucket_pgs = self.allreduce_communicators[0]
self.bucket_streams = self.allreduce_communicators[1]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)]
self.next_bucket = 0
self.ready_buckets_not_reduced = set()
self.active_params = param_list
self.callback_queued = False
if self.prof:
torch.cuda.nvtx.range_pop()
return result
......@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda()
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group
self.channel_last = channel_last
self.fuse_relu = fuse_relu
def _specify_process_group(self, process_group):
self.process_group = process_group
......@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
def _specify_channel_last(self, channel_last):
self.channel_last = channel_last
def forward(self, input):
def forward(self, input, z = None):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not channel_last:
if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
# fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else:
......@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else:
exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last)
return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
......@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function):
@staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False):
def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous()
world_size = 0
......@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std)
ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.process_group = process_group
ctx.channel_last = channel_last
ctx.world_size = world_size
ctx.fuse_relu = fuse_relu
if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias)
out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
......@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std = ctx.saved_tensors
saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
process_group = ctx.process_group
channel_last = ctx.channel_last
world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None
fuse_relu = ctx.fuse_relu
grad_input = grad_z = grad_weight = grad_bias = None
if fuse_relu:
grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output?
if channel_last:
......@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]:
if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None
if weight is None or not ctx.needs_input_grad[2]:
if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None
torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None
return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
......@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
float scale);
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale);
void multi_tensor_axpby_cuda(
int chunk_size,
at::Tensor noop_flag,
......@@ -72,6 +85,8 @@ void multi_tensor_novograd_cuda(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
"out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
......
......@@ -3,6 +3,9 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
......@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
......@@ -9,6 +9,10 @@
// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
#include "type_shim.h"
......@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
}
}
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<DEPTH>& tl,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
adamMode_t mode,
const float decay)
{
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* p = (T *)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T *)tl.addresses[1][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T *)tl.addresses[2][tensor_loc];
v += chunk_idx*chunk_size;
GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
g += chunk_idx*chunk_size;
GRAD_T* p_copy = NULL;
if (DEPTH == 5) {
p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
p_copy += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
T incoming_p[ILP];
T incoming_m[ILP];
T incoming_v[ILP];
T incoming_g[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP) {
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
incoming_p[ii] = 0;
incoming_m[ii] = 0;
incoming_v[ii] = 0;
incoming_g[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if (i < n && i < chunk_size) {
incoming_p[ii] = p[i];
incoming_m[ii] = m[i];
incoming_v[ii] = v[i];
incoming_g[ii] = static_cast<T>(g[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++) {
int j = i_start + threadIdx.x + ii*blockDim.x;
if(j < n && j < chunk_size) {
T scaled_grad = incoming_g[ii]/grad_scale;
m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*incoming_p[ii]);
p[j] = incoming_p[ii] - (step_size*update);
if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
}
}
}
}
};
void fused_adam_cuda(
at::Tensor & p,
at::Tensor & p_copy,
......@@ -96,7 +187,7 @@ void fused_adam_cuda(
AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
p.data<accscalar_t>(),
......@@ -112,7 +203,7 @@ void fused_adam_cuda(
tsize,
(adamMode_t) mode,
decay);
)
);
} else {
using namespace at;
DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
......@@ -135,3 +226,110 @@ void fused_adam_cuda(
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
float lr,
float beta1,
float beta2,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay) {
//Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - std::pow(beta1, step);
const float bias_correction2 = 1 - std::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
}
else {
step_size = lr;
}
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
size_t tl_sz = tensor_lists.size();
AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");
if (tensor_lists[3][0].scalar_type() == at::ScalarType::Half) {
//alher values should be fp32 for half gradients
AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dich is done on the gradient type
if (tl_sz == 5) {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, accscalar_t, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
} else {
if (tl_sz == 5) {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<5, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
} else {
DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_cuda_mt_kernel",
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<4, scalar_t_0, scalar_t_0>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
(adamMode_t) mode,
decay);
);
}
}
THCudaCheck(cudaGetLastError());
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "multi_tensor_apply.cuh"
#include "compat.h"
#include <assert.h>
#include <cuda_runtime.h>
#define BLOCK_SIZE 512
#define ILP 4
/**
* Perform fused SGD on multiple buffers
* N: number of tensors
* tl[0] : gradients
* tl[1] : weights
* tl[2] : momentum buffers
* tl[3] : fp16 weights (if appropriate)
* wd : weight_decay (scalar)
* momentum : momentum (scalar)
* dampening : momentum dampening (scalar)
* lr : learning rate (scalar)
* nesterov : enable nesterov (bool)
* first run : necessary for proper momentum handling & init
* wd_after_momentum : apply weight decay _after_ momentum instead of before
**/
template<int N, typename T_grad, typename T_weight>
struct SGDFunctor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<N>& tl,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
// Early exit if we don't need to do anything
if (*noop_gmem) return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
grad_in += chunk_idx*chunk_size;
T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
weight_in += chunk_idx*chunk_size;
T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
mom_in += chunk_idx*chunk_size;
at::Half *model_weights_out = nullptr;
if(N == 4)
{
model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
model_weights_out += chunk_idx*chunk_size;
}
n -= chunk_idx*chunk_size;
// Non-divergent exit condition for the __syncthreads
float incoming_grads[ILP];
float incoming_weights[ILP];
float incoming_moms[ILP];
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
incoming_grads[ii] = 0;
incoming_weights[ii] = 0;
incoming_moms[ii] = 0;
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
incoming_weights[ii] = static_cast<float>(weight_in[i]);
incoming_moms[ii] = static_cast<float>(mom_in[i]);
}
}
// note for clarification to future michael:
// From a pure memory dependency perspective, there's likely no point unrolling
// the write loop, since writes just fire off once their LDGs arrive.
// Put another way, the STGs are dependent on the LDGs, but not on each other.
// There is still compute ILP benefit from unrolling the loop though.
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
// apply weight decay before momentum if necessary
if(wd != 0.f && !wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
if(momentum != 0.f)
{
if(!first_run)
incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
else // initialize momentums to current incoming grads
incoming_moms[ii] = incoming_grads[ii];
if(nesterov)
incoming_grads[ii] += momentum * incoming_moms[ii];
else
incoming_grads[ii] = incoming_moms[ii];
}
// Apply WD after momentum if desired
if(wd != 0.f && wd_after_momentum)
incoming_grads[ii] += wd * incoming_weights[ii];
// adjust the weight and write out
weight_in[i] += (-lr * incoming_grads[ii]);
// if necessary, write out an fp16 copy of the weights
if(N == 4)
model_weights_out[i] = static_cast<at::Half>(weight_in[i]);
// also write out the new momentum
if(momentum != 0.f)
mom_in[i] = incoming_moms[ii];
}
}
}
}
};
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale)
{
auto num_tensors = tensor_lists.size();
auto grad_type = tensor_lists[0][0].scalar_type();
auto weight_type = tensor_lists[1][0].scalar_type();
if(num_tensors == 4)
for(int i = 0; i < tensor_lists[3].size(); i++)
TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
"Additional output tensors should always be fp16.");
// We have 3 possibilities to handle here, in terms of
// grad_type, param_type, momentum_type, requires_fp16_copy
// 1. fp16, fp16, fp16, No
// 2. fp32, fp32, fp32, No
// 3. fp16, fp32, fp32, Yes
// 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
// It's easier to hardcode these possibilities than to use
// switches etc. to handle the cross-product of cases where
// we don't want the majority of them.
// Case 1. fp16, fp16, fp16, No
if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Half &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, at::Half, at::Half>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 2. fp16, fp32, fp32, No
// else if (grad_type == at::ScalarType::Half &&
// weight_type == at::ScalarType::Float &&
// num_tensors == 3) {
// multi_tensor_apply<3>(
// BLOCK_SIZE,
// chunk_size,
// noop_flag,
// tensor_lists,
// SGDFunctor<3, at::Half, float>(),
// wd,
// momentum,
// dampening,
// lr,
// nesterov,
// first_run,
// wd_after_momentum);
// }
// Case 2. fp32, fp32, fp32, No
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 3)
{
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<3, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 3. fp16, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Half &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, at::Half, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
// Case 4. fp32, fp32, fp32, Yes
else if(grad_type == at::ScalarType::Float &&
weight_type == at::ScalarType::Float &&
num_tensors == 4)
{
multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
SGDFunctor<4, float, float>(),
wd,
momentum,
dampening,
lr,
nesterov,
first_run,
wd_after_momentum,
scale);
}
else
{
AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
"gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
}
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type;
......@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
......@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
}
......@@ -591,6 +591,58 @@ template <
int PARALLEL_LOADS>
__global__ void batchnorm_forward_c_last_kernel(
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
const layerscalar_t* __restrict__ shift,
scalar_t* __restrict__ out,
const int reduction_size,
const int stride,
const bool fuse_relu) {
// tensor dimension (m,c)
// loop along m dimension
int inner_loop_stride = blockDim.y * gridDim.y;
// offset along m dimension
int m_offset = blockIdx.y * blockDim.y + threadIdx.y;
int c_offset = blockIdx.x * blockDim.x + threadIdx.x;
auto m_c = mean[c_offset];
auto inv_std_c = static_cast<accscalar_t>(inv_std[c_offset]);
auto w_c = weight == NULL ? accscalar_t(1.0) : static_cast<accscalar_t>(weight[c_offset]);
auto s_c = shift == NULL ? accscalar_t(0.0) : static_cast<accscalar_t>(shift[c_offset]);
int loop_count = 1 + (reduction_size - 1) / (inner_loop_stride * PARALLEL_LOADS);
int address_base = m_offset * stride + c_offset;
int address_increment = inner_loop_stride * stride;
for (int i = 0; i < loop_count; i++) {
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (fuse_relu && tmp <= accscalar_t(0.0) ? scalar_t(0.0) : static_cast<scalar_t>(tmp));
}
m_offset += inner_loop_stride;
address_base += address_increment;
}
}
}
// elementwise BN kernel
template <
typename scalar_t,
typename accscalar_t,
typename layerscalar_t,
int PARALLEL_LOADS>
__global__ void relu_backward_c_last_kernel(
const scalar_t* __restrict__ grad_output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ z,
const accscalar_t* __restrict__ mean,
const accscalar_t* __restrict__ inv_std,
const layerscalar_t* __restrict__ weight,
......@@ -619,9 +671,11 @@ __global__ void batchnorm_forward_c_last_kernel(
#pragma unroll
for (int j = 0; j < PARALLEL_LOADS; j++) {
if (c_offset < stride && m_offset < reduction_size) {
out[address_base] = static_cast<scalar_t>(
w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c
);
auto tmp = w_c * (static_cast<accscalar_t>(input[address_base]) - m_c ) * inv_std_c + s_c;
if (z != NULL) {
tmp += z[address_base];
}
out[address_base] = (tmp <= accscalar_t(0.0) ? scalar_t(0.0) : grad_output[address_base]);
}
m_offset += inner_loop_stride;
address_base += address_increment;
......@@ -1147,10 +1201,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at::Tensor batchnorm_forward_c_last_CUDA(
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const at::optional<at::Tensor> shift,
const bool fuse_relu) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
......@@ -1170,13 +1226,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
} else {
if (weight.has_value()) {
......@@ -1189,13 +1247,15 @@ at::Tensor batchnorm_forward_c_last_CUDA(
batchnorm_forward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
stride,
fuse_relu);
);
}
return out;
......@@ -1351,3 +1411,66 @@ at::Tensor batchnorm_backward_c_last_CUDA(
return grad_input;
}
at::Tensor relu_backward_c_last_CUDA(
const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift) {
const auto stride = input.size(input.ndimension()-1);
const auto reduction_size = input.numel() / stride;
at::Tensor out = at::empty_like(input);
dim3 block;
dim3 grid;
flexible_launch_configs(reduction_size, stride, block, grid);
auto stream = at::cuda::getCurrentCUDAStream();
if (input.scalar_type() == at::ScalarType::Half
&& weight.has_value() && weight.value().scalar_type() == at::ScalarType::Float) {
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, accscalar_t, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<accscalar_t>() : NULL,
shift.has_value() ? shift.value().data<accscalar_t>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
} else {
if (weight.has_value()) {
TORCH_CHECK(input.scalar_type() == weight.value().scalar_type(),
"input.scalar_type() is not supported with weight.scalar_type()");
}
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "batchnorm_forward",
using accscalar_t = at::acc_type<scalar_t_0, true>;
relu_backward_c_last_kernel<scalar_t_0, accscalar_t, scalar_t_0, ELEMENTS_PER_ITER>
<<<grid, block, 0, stream>>>(
grad_output.data<scalar_t_0>(),
input.data<scalar_t_0>(),
z.has_value() ? z.value().data<scalar_t_0>() : NULL,
mean.data<accscalar_t>(),
inv_std.data<accscalar_t>(),
weight.has_value() ? weight.value().data<scalar_t_0>() : NULL,
shift.has_value() ? shift.value().data<scalar_t_0>(): NULL,
out.data<scalar_t_0>(),
reduction_size,
stride);
);
}
return out;
}
......@@ -25,54 +25,6 @@ try:
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
cudnn.benchmark = True
def fast_collate(batch):
imgs = [img[0] for img in batch]
......@@ -90,24 +42,75 @@ def fast_collate(batch):
return tensor, targets
best_prec1 = 0
args = parser.parse_args()
print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
def parse():
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))
if args.deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10)
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()
return args
def main():
global best_prec1, args
args = parse()
print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
cudnn.benchmark = True
best_prec1 = 0
if args.deterministic:
cudnn.benchmark = False
cudnn.deterministic = True
torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10)
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
......
......@@ -72,6 +72,7 @@ if "--cuda_ext" in sys.argv:
ext_modules.append(
CUDAExtension(name='amp_C',
sources=['csrc/amp_C_frontend.cpp',
'csrc/multi_tensor_sgd_kernel.cu',
'csrc/multi_tensor_scale_kernel.cu',
'csrc/multi_tensor_axpby_kernel.cu',
'csrc/multi_tensor_l2norm_kernel.cu',
......@@ -104,6 +105,58 @@ if "--cuda_ext" in sys.argv:
'-O3',
'--use_fast_math'] + version_ge_1_1}))
if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--bnp")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--bnp was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='bnp',
sources=['apex/contrib/csrc/groupbn/batch_norm.cu',
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
extra_compile_args={'cxx': [] + version_ge_1_1,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
'-D__CUDA_NO_HALF_CONVERSIONS__',
'-D__CUDA_NO_HALF2_OPERATORS__',
'-gencode',
'arch=compute_70,code=sm_70'] + version_ge_1_1}))
if "--xentropy" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--xentropy")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--xentropy was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
# Set up macros for forward/backward compatibility hack around
# https://github.com/pytorch/pytorch/commit/4404762d7dd955383acee92e6f06b48144a0742e
version_ge_1_1 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_1 = ['-DVERSION_GE_1_1']
ext_modules.append(
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=['csrc'],
extra_compile_args={'cxx': ['-O3'] + version_ge_1_1,
'nvcc':['-O3'] + version_ge_1_1}))
setup(
name='apex',
version='0.1',
......
......@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase):
fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting?
if __name__ == '__main__':
......
import unittest
import functools as ft
import itertools as it
from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
try:
import amp_C
disabled = False
from apex.optimizers import FusedSGD as FusedSGD
except ImportError as err:
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
disabled = True
class MyModel(torch.nn.Module):
def __init__(self, unique):
super(MyModel, self).__init__()
self.weight0 = Parameter(unique +
torch.arange(2, device='cuda', dtype=torch.float32))
self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
@staticmethod
def ops(input, weight0, weight1):
return ((input*(weight0.float()))*(weight1.float())).sum()
def forward(self, input):
return self.ops(input, self.weight0, self.weight1)
# Abandon all hope, ye who enter here.
# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly. Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.
class TestMultipleModelsOptimizersLosses(unittest.TestCase):
def setUp(self):
self.x = torch.ones((2), device='cuda', dtype=torch.float32)
common_init(self)
def tearDown(self):
pass
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], optimizer = amp.initialize(
[model0, model1],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses1optimizer(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125)
reference_grads = []
for i in range(2):
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
loss0.backward()
loss1.backward()
reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()] +
[param.grad.data.clone() for param in model2.parameters()])
optimizer.step()
final_params = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 2)
elif which_backward == 1:
which_models = (1, 2)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 0.5},
{'params' : model2.parameters(), 'lr' : 0.125}],
momentum=0.125,
materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], optimizer = amp.initialize(
[model0, model1, model2],
optimizer,
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer.zero_grad()
loss0 = model0(self.x) + model2(self.x)
loss1 = model1(self.x) + model2(self.x)
with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(which_model + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 1:
inj_model = model1
elif which_model == 2:
inj_model = model2
else:
raise RuntimeError(which_model + " invalid for loss 1 ")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = amp.master_params(optimizer)
for param, reference_grad in zip(master_params, reference_grads[unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()),
"opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
unskipped += 1
optimizer.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
for model, master, reference in zip(
model_params,
amp.master_params(optimizer),
final_params):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_2models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Don't do it like this: reference_grads = [[]]*5
# because then it creates a list of 5 references to the same "[]" and appending
# to any of them effectively makes you append to all of them, which multiplies
# the resulting size of reference_grads by 5x and needless to say makes the test fail.
reference_grads = [[], [], [], [], []]
final_params = [None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = [param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
def what_got_skipped(which_iter, which_backward):
if which_iter == 0 and which_backward == 0:
return 1
if which_iter == 0 and which_backward == 1:
return 2
if which_iter == 1 and which_backward == 0:
return 3
if which_iter == 1 and which_backward == 1:
return 4
return 0
for which_iter in (0,1):
for which_backward in (0,1):
model0 = MyModel(1)
model1 = MyModel(2)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125)
optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter, which_backward)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
optimizer1.step()
else:
optimizer0.step()
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
else:
iters = 2
model0 = MyModel(1)
model1 = MyModel(2)
models = [model0, model1]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
momentum=0.125, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1], [optimizer0, optimizer1] = amp.initialize(
[model0, model1],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x)
loss1 = model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if inject_inf_loc == "fp32":
model0.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model0.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if inject_inf_loc == "fp32":
model1.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
model1.weight1.grad[0] = float('inf')
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
@unittest.skipIf(disabled, "amp_C is unavailable")
def test_3models2losses2optimizers(self):
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
# Again, can't do this: reference_grads = [[]]*9
reference_grads = [[], [], [], [], [], [], [], [], []]
final_params = [None, None, None, None, None, None, None, None, None]
for i in range(2):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
optimizer0.step()
optimizer1.step()
final_params[0] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
def what_got_skipped(which_iter, which_backward, which_model):
if which_iter == 0:
if which_backward == 0:
if which_model == 0:
return 1
if which_model == 1:
return 2
if which_backward == 1:
if which_model == 2:
return 3
if which_model == 1:
return 4
if which_iter == 1:
if which_backward == 0:
if which_model == 0:
return 5
if which_model == 1:
return 6
if which_backward == 1:
if which_model == 2:
return 7
if which_model == 1:
return 8
return 0
for which_iter in (0,1):
for which_backward in (0,1):
if which_backward == 0:
which_models = (0,1)
if which_backward == 1:
which_models = (2,1)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5)
optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25)
for i in range(3):
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
loss0.backward()
loss1.backward()
if i != which_iter:
reference_grads[what_got_skipped(which_iter,
which_backward, which_model)].append(
[param.grad.data.clone() for param in model0.parameters()] +
[param.grad.data.clone() for param in model1.parameters()])
if i == which_iter:
if which_backward == 0:
# if which_model == 0:
optimizer1.step()
# if which_model == 1:
# optimizer1.step()
if which_backward == 1:
# if which_model == 2:
# optimizer0.step()
# if which_model == 1:
continue
else:
optimizer0.step()
optimizer1.step()
final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
[param.data.clone() for param in model0.parameters()] + \
[param.data.clone() for param in model1.parameters()] + \
[param.data.clone() for param in model2.parameters()]
for materialize_master_grads in (False, True):
for opt_level in ("O0", "O1", "O2", "O3"):
for how_to_zero in ("none", "model", "optimizer"):
for use_multiple_loss_scalers in (False, True):
if opt_level == "O1" or opt_level == "O2":
inject_inf_iters = (-1, 0, 1)
else:
inject_inf_iters = (-1,)
for inject_inf in inject_inf_iters:
if inject_inf >= 0:
inject_inf_locs = ("fp16", "fp32")
which_backwards = (0, 1)
else:
inject_inf_locs = ("fdsa",)
which_backwards = (None,)
for inject_inf_loc in inject_inf_locs:
for which_backward in which_backwards:
if use_multiple_loss_scalers:
num_losses = 2
loss_ids = [0, 1]
else:
num_losses = 1
loss_ids = [0, 0]
if inject_inf >= 0:
iters = 3
if which_backward == 0:
which_models = (0, 1)
elif which_backward == 1:
which_models = (2, 1)
else:
iters = 2
which_models = (None,)
for which_model in which_models:
model0 = MyModel(1)
model1 = MyModel(2)
model2 = MyModel(3)
models = [model0, model1, model2]
optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
{'params' : model1.parameters(), 'lr' : 1.0}],
momentum=0.5, materialize_master_grads=materialize_master_grads)
optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
momentum=0.25, materialize_master_grads=materialize_master_grads)
_amp_state.allow_incoming_model_not_fp32 = True
[model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
[model0, model1, model2],
[optimizer0, optimizer1],
opt_level=opt_level,
verbosity=0,
cast_model_type=False,
num_losses=num_losses)
_amp_state.allow_incoming_model_not_fp32 = False
_amp_state.loss_scalers[0]._loss_scale = 4.0
if use_multiple_loss_scalers:
_amp_state.loss_scalers[1]._loss_scale = 16.0
unskipped = 0
for i in range(iters):
if how_to_zero == "none":
for model in models:
for param in model.parameters():
param.grad = None
elif how_to_zero == "model":
for model in models:
model.zero_grad()
else:
optimizer0.zero_grad()
optimizer1.zero_grad()
loss0 = model0(self.x) + model1(self.x)
loss1 = model2(self.x) + model1(self.x)
with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 0:
if which_model == 0:
inj_model = model0
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(which_model + " invalid for loss 0")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
scaled_loss.backward()
if i == inject_inf and which_backward == 1:
if which_model == 2:
inj_model = model2
elif which_model == 1:
inj_model = model1
else:
raise RuntimeError(which_model + " invalid for loss 1 ")
if inject_inf_loc == "fp32":
inj_model.weight0.grad[0] = float('inf')
elif inject_inf_loc == "fp16":
inj_model.weight1.grad[0] = float('inf')
if i != inject_inf:
master_params = list(amp.master_params(optimizer0)) + \
list(amp.master_params(optimizer1))
for param, reference_grad in zip(master_params,
reference_grads[what_got_skipped(inject_inf,
which_backward, which_model)][unskipped]):
if opt_level == "O2" and not materialize_master_grads:
continue
else:
self.assertTrue(torch.allclose(param.grad.float(), reference_grad.float()))
unskipped += 1
optimizer0.step()
optimizer1.step()
model_params = [p for p in model0.parameters()] + \
[p for p in model1.parameters()] + \
[p for p in model2.parameters()]
master_params = [p for p in amp.master_params(optimizer0)] + \
[p for p in amp.master_params(optimizer1)]
# print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))
for model, master, reference in zip(
model_params,
master_params,
final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
self.assertTrue(torch.allclose(model, reference))
self.assertTrue(torch.allclose(model, master.to(model.dtype)))
if opt_level == "O1":
_amp_state.handle._deactivate()
if __name__ == '__main__':
unittest.main()
......@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
def tearDown(self):
pass
def gen_param_optim(self, tensors, adam_option):
def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
ref_param = []
tst_param = []
for tensor in tensors:
ref_param.append(torch.nn.Parameter(tensor.clone()))
tst_param.append(torch.nn.Parameter(tensor.clone()))
ref_optim = torch.optim.Adam(ref_param, **adam_option)
tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **adam_option)
ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
if tst_adam_option:
tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **tst_adam_option)
else:
tst_optim = apex.optimizers.FusedAdam_v1(tst_param, **ref_adam_option)
return (ref_param, tst_param, ref_optim, tst_optim)
......@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
def get_max_diff(self, ref_param, tst_param):
max_abs_diff = max_rel_diff = 0
for p_ref, p_tst in zip(ref_param, tst_param):
max_abs_diff_p = (p_ref - p_tst).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst) / p_ref).abs().max().item()
max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()
if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
......@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
def test_multi_tensor(self):
sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
ref_adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
'weight_decay':0, 'amsgrad':False}
tst_adam_option = dict(ref_adam_option, **{'use_mt':True})
tensors = []
fp16_params = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
ref_param, tst_param, ref_optim, tst_optim = \
self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)
for i in range(self.iters):
half_grads = self.gen_mixed_grad(ref_param, tst_param)
ref_optim.step()
tst_optim.step(grads=half_grads, output_params=fp16_params)
max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, \
fp16_params)
self.assertLessEqual(max_abs_diff, self.max_abs_diff)
self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
script_path = os.path.dirname(os.path.realpath(__file__))
......
#!/bin/bash
DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
# DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
# DATADIR="/opt/home/apex/examples/imagenet/"
cp ../common/* .
bash run_test.sh single_gpu $1 $DATADIR yes
bash run_test.sh single_gpu $1
......@@ -35,8 +35,9 @@ class Model(Module):
model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
model = DDP(model, delay_allreduce=True)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)
x = torch.cuda.FloatTensor(4096*4096)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment