Commit 4d6ed501 authored by Deyu Fu's avatar Deyu Fu
Browse files

Merge branch 'multi_tensor_sgd' into deyuf/fused_optimizer_v2

parents 690b1f71 9f64bf27
import torch
import xentropy_cuda
class SoftmaxCrossEntropyLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
losses, max_log_sum_exp = xentropy_cuda.forward(
logits, labels, smoothing, half_to_float)
losses.masked_fill_(labels==padding_idx, 0)
ctx.save_for_backward(logits, max_log_sum_exp, labels,
torch.FloatTensor([smoothing]),
torch.LongTensor([padding_idx]))
return losses
@staticmethod
def backward(ctx, grad_loss):
logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors
if not grad_loss.is_contiguous():
grad_loss = grad_loss.contiguous()
grad_loss.masked_fill_(labels==padding_idx.item(), 0)
grad_logits = xentropy_cuda.backward(
grad_loss.contiguous(), logits, max_log_sum_exp,
labels, smoothing.item())
return grad_logits, None, None, None, None
from .sgd import FusedSGD from .fused_sgd import FusedSGD
from .novograd import FusedNovoGrad from .novograd import FusedNovoGrad
from .fused_adam_v1 import FusedAdam_v1 from .fused_adam_v1 import FusedAdam_v1
from .adam import FusedAdam from .adam import FusedAdam
#from .sgd import FusedSGD
from .fp16_optimizer import FP16_Optimizer from .fp16_optimizer import FP16_Optimizer
...@@ -2,8 +2,9 @@ import types ...@@ -2,8 +2,9 @@ import types
import torch import torch
import importlib import importlib
class FusedAdam_v1(torch.optim.Optimizer): from ..multi_tensor_apply import multi_tensor_applier
class FusedAdam_v1(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via """Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
``python setup.py install --cuda_ext --cpp_ext``. ``python setup.py install --cuda_ext --cpp_ext``.
...@@ -25,6 +26,8 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -25,6 +26,8 @@ class FusedAdam_v1(torch.optim.Optimizer):
adds eps to the bias-corrected second moment estimate before adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False) second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam\: A Method for Stochastic Optimization: .. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980 https://arxiv.org/abs/1412.6980
...@@ -35,10 +38,21 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -35,10 +38,21 @@ class FusedAdam_v1(torch.optim.Optimizer):
def __init__(self, params, def __init__(self, params,
lr=1e-3, bias_correction = True, lr=1e-3, bias_correction = True,
betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False, betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
weight_decay=0., max_grad_norm=0., amsgrad=False): weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
amp_scale_adjustment=1.0):
global fused_adam_cuda global fused_adam_cuda
fused_adam_cuda = importlib.import_module("fused_adam_cuda") fused_adam_cuda = importlib.import_module("fused_adam_cuda")
self._use_multi_tensor = False
if use_mt:
if not multi_tensor_applier.available:
print("Warning: multi_tensor_applier is unavailable")
else:
self._use_multi_tensor = True
self._overflow_buf = torch.cuda.IntTensor([0])
self._amp_scale_adjustment = amp_scale_adjustment
if amsgrad: if amsgrad:
raise RuntimeError('FusedAdam does not support the AMSGrad variant.') raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction, defaults = dict(lr=lr, bias_correction=bias_correction,
...@@ -66,6 +80,12 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -66,6 +80,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
if closure is not None: if closure is not None:
loss = closure() loss = closure()
if hasattr(self, "_amp_stash"):
grads = self._amp_stash.grads
output_params = self._amp_stash.output_params
scale = self._amp_stash.scale*self._amp_scale_adjustment
grad_norms = self._amp_stash.grad_norms
if grads is None: if grads is None:
grads_group = [None]*len(self.param_groups) grads_group = [None]*len(self.param_groups)
# backward compatibility # backward compatibility
...@@ -105,6 +125,12 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -105,6 +125,12 @@ class FusedAdam_v1(torch.optim.Optimizer):
bias_correction = 1 if group['bias_correction'] else 0 bias_correction = 1 if group['bias_correction'] else 0
if self._use_multi_tensor:
if output_params:
tensorlists = [[],[],[],[],[]]
else:
tensorlists = [[],[],[],[]]
for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group): for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
#note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients #note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
if p.grad is None and grad is None: if p.grad is None and grad is None:
...@@ -130,6 +156,14 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -130,6 +156,14 @@ class FusedAdam_v1(torch.optim.Optimizer):
state['step'] += 1 state['step'] += 1
out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param out_p = torch.tensor([], dtype = torch.float) if output_param is None else output_param
if self._use_multi_tensor:
pl = [p.data, exp_avg, exp_avg_sq, grad]
if output_param is not None:
pl.append(out_p)
for tl, t in zip(tensorlists, pl):
tl.append(t)
else:
fused_adam_cuda.adam(p.data, fused_adam_cuda.adam(p.data,
out_p, out_p,
exp_avg, exp_avg,
...@@ -144,4 +178,21 @@ class FusedAdam_v1(torch.optim.Optimizer): ...@@ -144,4 +178,21 @@ class FusedAdam_v1(torch.optim.Optimizer):
self.eps_mode, self.eps_mode,
bias_correction, bias_correction,
group['weight_decay']) group['weight_decay'])
if self._use_multi_tensor:
multi_tensor_applier(
fused_adam_cuda.adam_mt,
self._overflow_buf,
tensorlists,
group['lr'],
beta1,
beta2,
group['eps'],
combined_scale,
state['step'],
self.eps_mode,
bias_correction,
group['weight_decay'])
return loss return loss
import torch
from torch.optim.optimizer import Optimizer, required
from apex.multi_tensor_apply import multi_tensor_applier
class FusedSGD(Optimizer):
r"""Implements stochastic gradient descent (optionally with momentum).
Nesterov momentum is based on the formula from
`On the importance of initialization and momentum in deep learning`__.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float): learning rate
momentum (float, optional): momentum factor (default: 0)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
dampening (float, optional): dampening for momentum (default: 0)
nesterov (bool, optional): enables Nesterov momentum (default: False)
Example:
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> optimizer.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> optimizer.step()
__ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
.. note::
The implementation of SGD with Momentum/Nesterov subtly differs from
Sutskever et. al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
.. math::
v = \rho * v + g \\
p = p - lr * v
where p, g, v and :math:`\rho` denote the parameters, gradient,
velocity, and momentum respectively.
This is in contrast to Sutskever et. al. and
other frameworks which employ an update of the form
.. math::
v = \rho * v + lr * g \\
p = p - v
The Nesterov version is analogously modified.
"""
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False,
wd_after_momentum=False,
materialize_master_grads=True):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super(FusedSGD, self).__init__(params, defaults)
self.wd_after_momentum = wd_after_momentum
self.materialize_master_grads = materialize_master_grads
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
if multi_tensor_applier.available:
import amp_C
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_sgd = amp_C.multi_tensor_sgd
else:
raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')
def __setstate__(self, state):
super(FusedSGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)
def get_momentums(self, params):
momentums = []
first_run = True
for p in params:
param_state = self.state[p]
# torch.optim.SGD initializes momentum in the main loop, we have
# to do it here, and track whether or not we've done so, so that
# momentum application can be skipped in the main kernel.
if 'momentum_buffer' not in param_state:
first_run = True
buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
momentums.append(buf)
else:
first_run = False
momentums.append(param_state['momentum_buffer'])
return momentums, first_run
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
explicit_master_params = (hasattr(self, "_amp_stash") and
hasattr(self._amp_stash, "fp32_from_fp16_groups"))
for gid, group in enumerate(self.param_groups):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
# For each group, there are 3 possible combinations we need to consider:
# grad_type, param_to_update_type, momentum_type, requires_fp16_model_copy
# 1. fp16, fp16, fp16, No
# 2. fp32, fp32, fp32, No
# 3. fp16, fp32, fp32, Yes
first_runs = [True, True]
# I think a bit of code divergence in exchange for naming clarity is worthwhile
if explicit_master_params:
stash = self._amp_stash
fp32_params = [p for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_grads = [p.grad for p in stash.fp32_from_fp32_groups[gid] if p.grad is not None]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
if self.materialize_master_grads:
fp16_model_params = [p for i, p in enumerate(
stash.fp16_groups[gid]) if stash.fp32_from_fp16_groups[gid][i].grad is not None]
fp32_from_fp16_grads = [p.grad for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for p in stash.fp32_from_fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp32_from_fp16_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
else:
fp16_model_params = [p for p in stash.fp16_groups[gid] if p.grad is not None]
fp16_model_grads = [p.grad for p in stash.fp16_groups[gid] if p.grad is not None]
fp32_from_fp16_params = [p for i, p in enumerate(
stash.fp32_from_fp16_groups[gid]) if stash.fp16_groups[gid][i].grad is not None]
fp32_from_fp16_momentums, first_runs[0] = self.get_momentums(fp32_from_fp16_params)
fp16_set = [fp16_model_grads, fp32_from_fp16_params,
fp32_from_fp16_momentums, fp16_model_params]
launch_sets= [fp16_set, [fp32_grads, fp32_params, fp32_momentums]]
else:
fp16_params = [p for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_grads = [p.grad for p in group['params'] if (p.dtype == torch.float16 and p.grad is not None)]
fp16_momentums, first_runs[0] = self.get_momentums(fp16_params)
fp32_params = [p for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_grads = [p.grad for p in group['params'] if (p.dtype == torch.float32 and p.grad is not None)]
fp32_momentums, first_runs[1] = self.get_momentums(fp32_params)
launch_sets = [[fp16_grads, fp16_params, fp16_momentums],
[fp32_grads, fp32_params, fp32_momentums]]
for s, (launch_set, first_run) in enumerate(zip(launch_sets, first_runs)):
assert len(launch_set[0]) == len(launch_set[1])
assert len(launch_set[0]) == len(launch_set[2])
if len(launch_set[0]) > 0:
multi_tensor_applier(
self.multi_tensor_sgd,
self._dummy_overflow_buf,
launch_set,
weight_decay,
momentum,
dampening,
group['lr'],
nesterov,
first_run,
self.wd_after_momentum,
1.0/self.most_recent_scale)
self.most_recent_scale = 1.0
self.scale_set_by_backward = False
return loss
...@@ -167,8 +167,12 @@ class DistributedDataParallel(Module): ...@@ -167,8 +167,12 @@ class DistributedDataParallel(Module):
allreduce_trigger_params=None, allreduce_trigger_params=None,
retain_allreduce_buffers=False, retain_allreduce_buffers=False,
allreduce_always_fp32=False, allreduce_always_fp32=False,
num_allreduce_streams=1,
allreduce_communicators=None,
gradient_average=True, gradient_average=True,
gradient_predivide_factor=1.0): gradient_predivide_factor=1.0,
gradient_average_split_factor=None,
prof=False):
super(DistributedDataParallel, self).__init__() super(DistributedDataParallel, self).__init__()
# Backward/forward compatibility around # Backward/forward compatibility around
...@@ -186,6 +190,19 @@ class DistributedDataParallel(Module): ...@@ -186,6 +190,19 @@ class DistributedDataParallel(Module):
self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False self.warn_on_half = True if self._backend == self.backend_enum_holder.GLOO else False
self.prof = prof
self.allreduce_different_streams = (num_allreduce_streams > 1)
self.num_allreduce_streams = num_allreduce_streams
self.allreduce_communicators = allreduce_communicators
if self.allreduce_communicators:
assert len(allreduce_communicators[0]) == num_allreduce_streams
assert len(allreduce_communicators[0]) == len(allreduce_communicators[1])
assert self.allreduce_different_streams
if self.allreduce_different_streams and delay_allreduce:
raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if shared_param is not None: if shared_param is not None:
raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.") raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")
...@@ -206,8 +223,10 @@ class DistributedDataParallel(Module): ...@@ -206,8 +223,10 @@ class DistributedDataParallel(Module):
self.delay_allreduce = delay_allreduce self.delay_allreduce = delay_allreduce
self.message_size = message_size self.message_size = message_size
self.reduction_stream = torch.cuda.Stream() self.main_stream = torch.cuda.current_stream()
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False)
self.bucket_streams = []
self.bucket_events = []
self.module = module self.module = module
...@@ -236,15 +255,21 @@ class DistributedDataParallel(Module): ...@@ -236,15 +255,21 @@ class DistributedDataParallel(Module):
def __setstate__(self, state): def __setstate__(self, state):
super(DistributedDataParallel, self).__setstate__(state) super(DistributedDataParallel, self).__setstate__(state)
self.reduction_stream = torch.cuda.Stream() if self.allreduce_different_streams and delay_allreduce:
self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) raise ValueError("self.allreduce_different_streams may only be used if delay_allreduce=False.")
if self.delay_allreduce:
self.needs_refresh = True
self.bucket_streams = []
self.bucket_events = []
def __getstate__(self): def __getstate__(self):
attrs = copy.copy(self.__dict__) attrs = copy.copy(self.__dict__)
if self._backend != self.backend_enum_holder.NCCL: if self._backend != self.backend_enum_holder.NCCL:
del attrs['self.reduction_stream'] del attrs['self.bucket_streams']
del attrs['self.reduction_event'] del attrs['self.bucket_events']
return attrs return attrs
def enable_allreduce(self): def enable_allreduce(self):
...@@ -308,8 +333,9 @@ class DistributedDataParallel(Module): ...@@ -308,8 +333,9 @@ class DistributedDataParallel(Module):
def overlapping_backward_epilogue(): def overlapping_backward_epilogue():
self.reduction_stream.record_event(self.reduction_event) for stream, event in zip(self.bucket_streams, self.bucket_events):
torch.cuda.current_stream().wait_event(self.reduction_event) stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
# Sanity checks that all the buckets were kicked off # Sanity checks that all the buckets were kicked off
if self.next_bucket != self.num_buckets: if self.next_bucket != self.num_buckets:
...@@ -330,6 +356,9 @@ class DistributedDataParallel(Module): ...@@ -330,6 +356,9 @@ class DistributedDataParallel(Module):
grad_acc = param_tmp.grad_fn.next_functions[0][0] grad_acc = param_tmp.grad_fn.next_functions[0][0]
def allreduce_hook(*unused): def allreduce_hook(*unused):
if self.prof:
torch.cuda.nvtx.range_push("allreduce_hook")
if not self._disable_allreduce: if not self._disable_allreduce:
if self.delay_allreduce or self.needs_refresh: if self.delay_allreduce or self.needs_refresh:
# TODO: How do we want to handle multiple backward passes between # TODO: How do we want to handle multiple backward passes between
...@@ -370,14 +399,44 @@ class DistributedDataParallel(Module): ...@@ -370,14 +399,44 @@ class DistributedDataParallel(Module):
self.comm_ready_buckets(param) self.comm_ready_buckets(param)
if self.prof:
torch.cuda.nvtx.range_pop()
grad_acc.register_hook(allreduce_hook) grad_acc.register_hook(allreduce_hook)
self.grad_accs.append(grad_acc) self.grad_accs.append(grad_acc)
wrapper(param) wrapper(param)
def allreduce_bucket(self, bucket):
def _stream_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_streams[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_streams[0]
def _event_this_bucket(self, bucket_idx):
if self.allreduce_different_streams:
return self.bucket_events[bucket_idx%self.num_allreduce_streams]
else:
return self.bucket_events[0]
def allreduce_bucket(self, bucket, bucket_idx, force_default_stream):
tensor = flatten(bucket) tensor = flatten(bucket)
if force_default_stream:
bucket_stream = self.main_stream
else:
bucket_stream = self._stream_this_bucket(bucket_idx)
bucket_event = self._event_this_bucket(bucket_idx)
torch.cuda.current_stream().record_event(bucket_event)
bucket_stream.wait_event(bucket_event)
with torch.cuda.stream(bucket_stream):
# self.main_stream.wait_stream(torch.cuda.current_stream())
# torch.cuda.synchronize()
tensor_to_allreduce = tensor tensor_to_allreduce = tensor
if self.allreduce_always_fp32: if self.allreduce_always_fp32:
...@@ -386,38 +445,57 @@ class DistributedDataParallel(Module): ...@@ -386,38 +445,57 @@ class DistributedDataParallel(Module):
if self.gradient_predivide_factor != 1.0: if self.gradient_predivide_factor != 1.0:
tensor_to_allreduce.mul_(1./self.gradient_predivide_factor) tensor_to_allreduce.mul_(1./self.gradient_predivide_factor)
if self.allreduce_different_streams and not force_default_stream:
dist.all_reduce(tensor_to_allreduce, group=self.bucket_pgs[bucket_idx%self.num_allreduce_streams])
else:
dist.all_reduce(tensor_to_allreduce) dist.all_reduce(tensor_to_allreduce)
if self.gradient_average: if self.gradient_average:
if self.gradient_predivide_factor != self.world_size:
tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size) tensor_to_allreduce.mul_(self.gradient_predivide_factor/self.world_size)
if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce: if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce:
tensor.copy_(tensor_to_allreduce) tensor.copy_(tensor_to_allreduce)
if not self.retain_allreduce_buffers:
if multi_tensor_applier.available:
multi_tensor_applier(
self.multi_tensor_scale,
self._overflow_buf,
[unflatten(tensor, bucket), bucket],
1.0)
else:
for buf, synced in zip(bucket, unflatten(tensor, bucket)):
buf.copy_(synced)
# I think we actually do need this here. After allreduce_bucket returns, tensor will
# eventually go out of scope and die, at which point it could otherwise be freed for
# further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
tensor.record_stream(bucket_stream)
return tensor return tensor
def allreduce_maybe_retain(self, bucket, bucket_idx=-1): def allreduce_maybe_retain(self, bucket, bucket_idx, force_default_stream=False):
allreduced = self.allreduce_bucket(bucket) allreduced = self.allreduce_bucket(bucket, bucket_idx, force_default_stream)
if self.retain_allreduce_buffers: if self.retain_allreduce_buffers:
if self.allreduce_buffers[bucket_idx] is not None: if self.allreduce_buffers[bucket_idx] is not None:
raise RuntimeError("The backward pass is attempting to replace an already-filled " raise RuntimeError("The backward pass is attempting to replace an already-filled "
"allreduce buffer. This is almost certainly an error.") "allreduce buffer. This is almost certainly an error.")
self.allreduce_buffers[bucket_idx] = allreduced self.allreduce_buffers[bucket_idx] = allreduced
else: for view, grad in zip(unflatten(allreduced, bucket), bucket):
if multi_tensor_applier.available: grad.data = view
multi_tensor_applier( # for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
self.multi_tensor_scale, # buf.copy_(synced)
self._overflow_buf,
[unflatten(allreduced, bucket), bucket],
1.0)
else:
for buf, synced in zip(bucket, unflatten(allreduced, bucket)):
buf.copy_(synced)
def allreduce_fallback(self): def allreduce_fallback(self):
for stream, event in zip(self.bucket_streams, self.bucket_events):
stream.record_event(event)
torch.cuda.current_stream().wait_event(event)
if self.retain_allreduce_buffers:
grads = [param.grad for param in self.module.parameters() if param.grad is not None]
else:
grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
split_buckets = split_half_float_double(grads) split_buckets = split_half_float_double(grads)
...@@ -429,12 +507,14 @@ class DistributedDataParallel(Module): ...@@ -429,12 +507,14 @@ class DistributedDataParallel(Module):
self.allreduce_buffers = [None for _ in range(len(split_buckets))] self.allreduce_buffers = [None for _ in range(len(split_buckets))]
for i, bucket in enumerate(split_buckets): for i, bucket in enumerate(split_buckets):
allreduced = self.allreduce_maybe_retain(bucket, i) allreduced = self.allreduce_maybe_retain(bucket, i, force_default_stream=True)
def comm_ready_buckets(self, param): def comm_ready_buckets(self, param):
# Need to do this in every hook for compatibility with Ruberry's streaming backward PR. # Need to do this in every hook for compatibility with Ruberry's streaming backward PR.
# self.reduction_stream.wait_stream(torch.cuda.current_stream()) # self.reduction_stream.wait_stream(torch.cuda.current_stream())
if self.prof:
torch.cuda.nvtx.range_push("comm_ready_buckets")
bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)] bucket_idx, bucket_loc = self.param_id_to_bucket[id(param)]
...@@ -442,14 +522,15 @@ class DistributedDataParallel(Module): ...@@ -442,14 +522,15 @@ class DistributedDataParallel(Module):
raise RuntimeError("The backward pass is attempting to replace an already-filled " raise RuntimeError("The backward pass is attempting to replace an already-filled "
"bucket slot. This is almost certainly an error.") "bucket slot. This is almost certainly an error.")
if self.retain_allreduce_buffers:
self.buckets[bucket_idx][bucket_loc] = param.grad
else:
self.buckets[bucket_idx][bucket_loc] = param.grad.data self.buckets[bucket_idx][bucket_loc] = param.grad.data
self.buckets_ready_size[bucket_idx] += 1 self.buckets_ready_size[bucket_idx] += 1
if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]: if self.buckets_ready_size[bucket_idx] == self.bucket_sizes[bucket_idx]:
if bucket_idx == self.next_bucket: if bucket_idx == self.next_bucket:
torch.cuda.current_stream().record_event(self.reduction_event)
self.reduction_stream.wait_event(self.reduction_event)
with torch.cuda.stream(self.reduction_stream):
self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx) self.allreduce_maybe_retain(self.buckets[bucket_idx], bucket_idx)
self.next_bucket += 1 self.next_bucket += 1
...@@ -471,10 +552,16 @@ class DistributedDataParallel(Module): ...@@ -471,10 +552,16 @@ class DistributedDataParallel(Module):
else: else:
self.ready_buckets_not_reduced.add(bucket_idx) self.ready_buckets_not_reduced.add(bucket_idx)
if self.prof:
torch.cuda.nvtx.range_pop()
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
result = self.module(*inputs, **kwargs) result = self.module(*inputs, **kwargs)
if self.prof:
torch.cuda.nvtx.range_push("forward pass DDP logic")
if not self._disable_allreduce: if not self._disable_allreduce:
if not self.delay_allreduce: if not self.delay_allreduce:
param_list = [param for param in self.module.parameters() if param.requires_grad] param_list = [param for param in self.module.parameters() if param.requires_grad]
...@@ -496,9 +583,46 @@ class DistributedDataParallel(Module): ...@@ -496,9 +583,46 @@ class DistributedDataParallel(Module):
self.bucket_sizes = [] self.bucket_sizes = []
self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)} self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
self.param_id_to_bucket = {} self.param_id_to_bucket = {}
self.bucket_pgs = []
self.bucket_streams = []
self.bucket_events = []
else: else:
# self.buckets = [[None for _ in range(self.bucket_sizes[i])]
# for i in range(self.num_buckets)]
if not self.buckets:
self.buckets = [[None for _ in range(self.bucket_sizes[i])] self.buckets = [[None for _ in range(self.bucket_sizes[i])]
for i in range(self.num_buckets)] for i in range(self.num_buckets)]
else:
assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
len(self.buckets), self.num_buckets)
for b, bucket in enumerate(self.buckets):
assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
b, len(buckets[b]), self.bucket_sizes[b])
for i in range(len(bucket)):
bucket[i] = None
if self.allreduce_communicators:
self.bucket_pgs = self.allreduce_communicators[0]
self.bucket_streams = self.allreduce_communicators[1]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if self.allreduce_different_streams:
if not self.bucket_pgs:
self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
for i, bg in enumerate(self.bucket_pgs):
print("rank {} created group {} with backend {}".format(
dist.get_rank(), i, dist.get_backend(bg)))
if self.allreduce_different_streams:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
self.bucket_events = [torch.cuda.Event(enable_timing=False,
blocking=False) for _ in range(self.num_allreduce_streams)]
else:
if not self.bucket_streams:
self.bucket_streams = [torch.cuda.Stream()]
self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]
self.buckets_ready_size = [0 for i in range(self.num_buckets)] self.buckets_ready_size = [0 for i in range(self.num_buckets)]
if(self.retain_allreduce_buffers): if(self.retain_allreduce_buffers):
self.allreduce_buffers = [None for _ in range(self.num_buckets)] self.allreduce_buffers = [None for _ in range(self.num_buckets)]
...@@ -509,4 +633,7 @@ class DistributedDataParallel(Module): ...@@ -509,4 +633,7 @@ class DistributedDataParallel(Module):
self.callback_queued = False self.callback_queued = False
if self.prof:
torch.cuda.nvtx.range_pop()
return result return result
...@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm): ...@@ -55,10 +55,11 @@ class SyncBatchNorm(_BatchNorm):
>>> inp = torch.randn(10, 14, 14, 100).cuda() >>> inp = torch.randn(10, 14, 14, 100).cuda()
""" """
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False): def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False):
super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.process_group = process_group self.process_group = process_group
self.channel_last = channel_last self.channel_last = channel_last
self.fuse_relu = fuse_relu
def _specify_process_group(self, process_group): def _specify_process_group(self, process_group):
self.process_group = process_group self.process_group = process_group
...@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm): ...@@ -66,11 +67,11 @@ class SyncBatchNorm(_BatchNorm):
def _specify_channel_last(self, channel_last): def _specify_channel_last(self, channel_last):
self.channel_last = channel_last self.channel_last = channel_last
def forward(self, input): def forward(self, input, z = None):
# if input.dim() == 2, we switch to channel_last for efficient memory accessing # if input.dim() == 2, we switch to channel_last for efficient memory accessing
channel_last = self.channel_last if input.dim() != 2 else True channel_last = self.channel_last if input.dim() != 2 else True
if not self.training and self.track_running_stats and not channel_last: if not self.training and self.track_running_stats and not self.channel_last and not self.fuse_relu and z == None:
# fall back to pytorch implementation for inference # fall back to pytorch implementation for inference
return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps) return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
else: else:
...@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm): ...@@ -81,4 +82,4 @@ class SyncBatchNorm(_BatchNorm):
exponential_average_factor = 1.0 / float(self.num_batches_tracked) exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: else:
exponential_average_factor = self.momentum exponential_average_factor = self.momentum
return SyncBatchnormFunction.apply(input, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, channel_last) return SyncBatchnormFunction.apply(input, z, self.weight, self.bias, self.running_mean, self.running_var, self.eps, self.training or not self.track_running_stats, exponential_average_factor, self.process_group, self.channel_last, self.fuse_relu)
...@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp ...@@ -7,7 +7,7 @@ from apex.parallel import ReduceOp
class SyncBatchnormFunction(Function): class SyncBatchnormFunction(Function):
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False): def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, track_running_stats = True, momentum = 1.0, process_group = None, channel_last = False, fuse_relu = False):
torch.cuda.nvtx.range_push("sync_BN_fw") torch.cuda.nvtx.range_push("sync_BN_fw")
input = input.contiguous() input = input.contiguous()
world_size = 0 world_size = 0
...@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function): ...@@ -53,13 +53,14 @@ class SyncBatchnormFunction(Function):
mean = running_mean.data mean = running_mean.data
inv_std = 1.0 / torch.sqrt(running_variance.data + eps) inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
ctx.save_for_backward(input, weight, mean, inv_std) ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
ctx.process_group = process_group ctx.process_group = process_group
ctx.channel_last = channel_last ctx.channel_last = channel_last
ctx.world_size = world_size ctx.world_size = world_size
ctx.fuse_relu = fuse_relu
if channel_last: if channel_last:
out = syncbn.batchnorm_forward_c_last(input, mean, inv_std, weight, bias) out = syncbn.batchnorm_forward_c_last(input, z, mean, inv_std, weight, bias, fuse_relu)
else: else:
out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias) out = syncbn.batchnorm_forward(input, mean, inv_std, weight, bias)
...@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function): ...@@ -73,11 +74,17 @@ class SyncBatchnormFunction(Function):
# mini batch mean & var are calculated by forward path. # mini batch mean & var are calculated by forward path.
# mu = 1./N*np.sum(h, axis = 0) # mu = 1./N*np.sum(h, axis = 0)
# var = 1./N*np.sum((h-mu)**2, axis = 0) # var = 1./N*np.sum((h-mu)**2, axis = 0)
saved_input, weight, mean, inv_std = ctx.saved_tensors saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
process_group = ctx.process_group process_group = ctx.process_group
channel_last = ctx.channel_last channel_last = ctx.channel_last
world_size = ctx.world_size world_size = ctx.world_size
grad_input = grad_weight = grad_bias = None fuse_relu = ctx.fuse_relu
grad_input = grad_z = grad_weight = grad_bias = None
if fuse_relu:
grad_output = syncbn.relu_bw_c_last(grad_output, saved_input, z, mean, inv_std, weight, bias)
if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
grad_z = grad_output.clone()
# TODO(jie): why do I have to clone here? life time of grad_output? # TODO(jie): why do I have to clone here? life time of grad_output?
if channel_last: if channel_last:
...@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function): ...@@ -100,11 +107,11 @@ class SyncBatchnormFunction(Function):
else: else:
grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu) grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
if weight is None or not ctx.needs_input_grad[1]: if weight is None or not ctx.needs_input_grad[2]:
grad_weight = None grad_weight = None
if weight is None or not ctx.needs_input_grad[2]: if weight is None or not ctx.needs_input_grad[3]:
grad_bias = None grad_bias = None
torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop()
return grad_input, grad_weight, grad_bias, None, None, None, None, None, None, None return grad_input, grad_z, grad_weight, grad_bias, None, None, None, None, None, None, None, None
...@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda( ...@@ -6,6 +6,19 @@ void multi_tensor_scale_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
float scale); float scale);
void multi_tensor_sgd_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
float wd,
float momentum,
float dampening,
float lr,
bool nesterov,
bool first_run,
bool wd_after_momentum,
float scale);
void multi_tensor_axpby_cuda( void multi_tensor_axpby_cuda(
int chunk_size, int chunk_size,
at::Tensor noop_flag, at::Tensor noop_flag,
...@@ -72,6 +85,8 @@ void multi_tensor_novograd_cuda( ...@@ -72,6 +85,8 @@ void multi_tensor_novograd_cuda(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda, m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
"Fused overflow check + scale for a list of contiguous tensors"); "Fused overflow check + scale for a list of contiguous tensors");
m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
"Fused SGD optimizer for list of contiguous tensors");
m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda, m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
"out = a*x + b*y for a list of contiguous tensors"); "out = a*x + b*y for a list of contiguous tensors");
m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda, m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
......
...@@ -3,6 +3,9 @@ ...@@ -3,6 +3,9 @@
// CUDA forward declaration // CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay); void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a ...@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation."); m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
} }
This diff is collapsed.
This diff is collapsed.
...@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input); ...@@ -55,10 +55,12 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
// mean/inv_std have promoted data type (dtype==fp16?fp32:dtype) // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
// expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input, at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean, const at::Tensor mean,
const at::Tensor inv_std, const at::Tensor inv_std,
const at::optional<at::Tensor> weight, const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift); const at::optional<at::Tensor> shift,
const bool fuse_relu);
// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias} // backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
// grad_output/input should have identical data type; // grad_output/input should have identical data type;
...@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output, ...@@ -82,6 +84,15 @@ at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor mean_dy, const at::Tensor mean_dy,
const at::Tensor mean_dy_xmu); const at::Tensor mean_dy_xmu);
at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
const at::Tensor input,
const at::optional<at::Tensor> z,
const at::Tensor mean,
const at::Tensor inv_std,
const at::optional<at::Tensor> weight,
const at::optional<at::Tensor> shift);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance"); m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance"); m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
...@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -92,4 +103,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc"); m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc"); m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc"); m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
} }
This diff is collapsed.
...@@ -25,89 +25,92 @@ try: ...@@ -25,89 +25,92 @@ try:
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
model_names = sorted(name for name in models.__dict__
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
def parse():
model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__") if name.islower() and not name.startswith("__")
and callable(models.__dict__[name])) and callable(models.__dict__[name]))
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
help='path to dataset') help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
choices=model_names, choices=model_names,
help='model architecture: ' + help='model architecture: ' +
' | '.join(model_names) + ' | '.join(model_names) +
' (default: resnet18)') ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)') help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N', parser.add_argument('--epochs', default=90, type=int, metavar='N',
help='number of total epochs to run') help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int, parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size per process (default: 256)') metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.') metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum') help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)') metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH', parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
help='evaluate model on validation set') help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model') help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int, parser.add_argument('--prof', default=-1, type=int,
help='Only run 10 iterations for profiling.') help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true') parser.add_argument('--deterministic', action='store_true')
parser.add_argument("--local_rank", default=0, type=int) parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true', parser.add_argument('--sync_bn', action='store_true',
help='enabling apex sync BN.') help='enabling apex sync BN.')
parser.add_argument('--opt-level', type=str) parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None)
args = parser.parse_args()
cudnn.benchmark = True return args
def fast_collate(batch):
imgs = [img[0] for img in batch]
targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
w = imgs[0].size[0]
h = imgs[0].size[1]
tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
for i, img in enumerate(imgs):
nump_array = np.asarray(img, dtype=np.uint8)
if(nump_array.ndim < 3):
nump_array = np.expand_dims(nump_array, axis=-1)
nump_array = np.rollaxis(nump_array, 2)
tensor[i] += torch.from_numpy(nump_array)
return tensor, targets
best_prec1 = 0 def main():
args = parser.parse_args() global best_prec1, args
print("opt_level = {}".format(args.opt_level)) args = parse()
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("opt_level = {}".format(args.opt_level))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
if args.deterministic: cudnn.benchmark = True
best_prec1 = 0
if args.deterministic:
cudnn.benchmark = False cudnn.benchmark = False
cudnn.deterministic = True cudnn.deterministic = True
torch.manual_seed(args.local_rank) torch.manual_seed(args.local_rank)
torch.set_printoptions(precision=10) torch.set_printoptions(precision=10)
def main():
global best_prec1, args
args.distributed = False args.distributed = False
if 'WORLD_SIZE' in os.environ: if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.distributed = int(os.environ['WORLD_SIZE']) > 1
......
This diff is collapsed.
...@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase): ...@@ -137,26 +137,6 @@ class TestTensorCasts(unittest.TestCase):
fn = lambda x: x.sum() fn = lambda x: x.sum()
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h)) run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
class TestDisabledCasts(unittest.TestCase):
def setUp(self):
self.handle = amp.init(enabled=False)
common_init(self)
def test_disabled_linear(self):
m = nn.Linear(self.h, self.h)
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
input_shape = (self.b, self.h)
for fn in [m, f]:
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
y = fn(x)
self.assertEqual(y.type(), FLOAT)
y.sum().backward()
self.assertEqual(x.grad.type(), FLOAT)
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
self.assertRaises(RuntimeError, fn, x)
# TODO: maybe more tests on disabled casting? # TODO: maybe more tests on disabled casting?
if __name__ == '__main__': if __name__ == '__main__':
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment