Commit 8a32e428 authored by Michael Carilli

Merging in master

parents d9c887c2 18f2eaee
...@@ -80,12 +80,12 @@ CUDA and C++ extensions via
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build (required with PyTorch 0.4) via
```
$ pip install -v --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
......
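The CUDA-extension build above requires that the CUDA toolkit on your PATH match the version your PyTorch binary was built against; a mismatch is a common cause of build failures and undefined-symbol errors. A minimal pre-flight check (a hedged sketch, not part of this commit):
```python
# Compare PyTorch's CUDA version against the nvcc visible on PATH before
# running the --cuda_ext build. Requires nvcc to be installed and on PATH.
import subprocess
import torch

print("PyTorch was built with CUDA", torch.version.cuda)
print(subprocess.check_output(["nvcc", "--version"]).decode())
```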
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
from . import parallel
from . import amp
from . import fp16_utils
......
# amp: Automatic Mixed Precision
## This README documents the deprecated (pre-unified) API.
## Documentation for the current unified API can be found [here](https://nvidia.github.io/apex/)
amp is an experimental tool to enable mixed precision training in
PyTorch with extreme simplicity and overall numerical safety. It
does so by employing a whitelist / blacklist model:
- Any function on the whitelist casts its input arguments to
fp16. These are functions like `torch.conv2d` that can take
advantage of TensorCore execution.
- Any function on the blacklist casts its input arguments to
fp32. These are functions like `torch.exp` or loss functions that
have trouble with the numerical properties of fp16.
- Any other function passes along its input types to its outputs. Care
is taken so that multi-argument functions or methods
(e.g. `torch.tensor.__add__`) can handle mixed type inputs. They
simply promote all inputs to have the widest type of any input.
The PyTorch hooks that enable the necessary casts sit at PyTorch's
low-level functional interface, so even custom layers will work with
amp as long as they are built out of PyTorch functions and methods.
In particular, amp hooks into all of the following:
- Functions in the top-level `torch` namespace
- Functions in the `torch.nn.functional` namespace
- Methods on `Tensor` objects (GPU only, fp16 and fp32)
- Custom support for RNNs, even though they have no direct functional
interface:
- Recurrent cells: `torch.nn.{RNNCell, LSTMCell, GRUCell}`
- Recurrent layers: `torch.nn.{RNN, LSTM, GRU}`
In a few limited cases, amp needs help finding custom user-defined
functions that use low-level PyTorch features. In those cases, a
simple annotation is sufficient; this is described below.
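To make the claim about custom layers concrete, here is a minimal hedged sketch (the module, names, and shapes are invented for illustration and are not taken from this README):
```python
import torch
import torch.nn.functional as F
from apex import amp

amp_handle = amp.init()

class GatedLinear(torch.nn.Module):
    """A custom layer built only from torch / torch.nn.functional calls."""
    def __init__(self, d_in, d_out):
        super(GatedLinear, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn(d_out, d_in))
        self.gate = torch.nn.Parameter(torch.randn(d_out, d_in))

    def forward(self, x):
        # F.linear is one of the functions amp intercepts (a whitelist-style
        # cast), and the pointwise sigmoid/multiply follow the promotion
        # rules, so no annotation is needed for this layer.
        return F.linear(x, self.weight) * torch.sigmoid(F.linear(x, self.gate))

layer = GatedLinear(64, 64).cuda()
out = layer(torch.randn(8, 64).cuda())
```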
## Installation and Requirements
amp is developed on Python 3.6 and PyTorch 0.4. It takes care to be
backwards-compatible with PyTorch 0.3, but users are _highly_
encouraged to upgrade.
amp is installed during normal apex installation, so refer to the
top-level README for more on installation.
## Usage and Getting Started
In the common case, using amp requires adding two lines of code (and
an import). The first enables amp, so that it can hook into all the
relevant PyTorch functions. The second tells it where backpropagation
occurs so that it can properly scale the loss and clear internal
per-iteration state.
#### 1. Enable amp
```python
from apex import amp
amp_handle = amp.init()
```
`amp.init()` takes three (optional) arguments. The most useful is
`enabled` (default=True), which makes it easy to toggle amp from a
command-line flag. If `enabled=False`, everything amp does becomes a
zero-overhead pass-through -- i.e., your code runs as-is.
For the other two options, the defaults are _highly_ recommended. The
first, `enable_caching` (default=True), indicates whether amp should
cache fp16 casts of model parameters on a per-iteration basis. This
prevents things like RNN cells used inside a loop from casting their
weight matrices over and over. The second, `verbose` (default=False),
toggles whether to print out every cast that occurs; this is mostly
useful for debugging.
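To make the `enabled` option concrete, wiring it to a command-line flag (the flag name below is an assumption, not something the README prescribes) lets one script run with or without mixed precision:
```python
import argparse
from apex import amp

parser = argparse.ArgumentParser()
parser.add_argument('--use-amp', action='store_true')  # hypothetical flag
args = parser.parse_args()

# With enabled=False, every amp call is a zero-overhead pass-through.
amp_handle = amp.init(enabled=args.use_amp, verbose=False)
```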
#### 2. Wrap backpropagation
Nearly all PyTorch training scripts have a loop that looks like:
```python
# ... do a bunch of stuff to compute a loss
loss.backward()
optimizer.step()
# ...finish the iteration
```
To use amp, you need only tell it where backprop occurs:
```python
# ... same as before
with amp_handle.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
# ... same as before
```
This context manager allows amp to:
1. Use automatic loss scaling to best use fp16 range
2. Clear its cache of casted parameters before the next optimizer step
Note that it is _possible_ to use amp without step 2, in which case
you will not get automatic loss scaling, nor is it safe to use
`enable_caching`. (Power user note: you can manually clear the cache
after each optimizer step with `amp_handle._clear_cache()`.)
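As a hedged illustration of that caveat (the tiny model is invented for the example):
```python
import torch
from apex import amp

amp_handle = amp.init(enable_caching=False)  # caching is unsafe without step 2

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 16).cuda()).sum()
loss.backward()      # no scale_loss context, so no automatic loss scaling
optimizer.step()

# Power-user alternative if caching is left enabled:
# amp_handle._clear_cache()   # call manually after each optimizer.step()
```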
## Multiple Optimizers or Backward Passes
Step (2) from the previous section works when you have one PyTorch
optimizer and a single `loss.backward()` for each iteration. Some
models are more complex with:
- Multiple optimizer objects (over different parameters)
- Multiple backward passes for each iteration, taking advantage of
PyTorch's gradient accumulation
To work with such models, amp requires you to explicitly wrap each
optimizer and indicate whether it will have more than one backward pass
per iteration.
#### Explicitly wrapping optimizers
If you have more than one optimizer, then you must explicitly wrap
each. (You can also do so with a single optimizer.) First, wrap the
optimizer after initializing amp:
```python
optimizer = # ... some optimizer
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer)
```
Second, use `optimizer.scale_loss(...)` to indicate where backprop
occurs:
```python
with optimizer.scale_loss(loss) as scaled_loss:
scaled_loss.backward()
optimizer.step()
# ...
```
In essence, `amp_handle.scale_loss(loss, optimizer)` is syntactic
sugar for first wrapping the optimizer and then calling
`optimizer.scale_loss(loss)` in the single-optimizer case. But in the
multi-optimizer case, you must wrap each optimizer individually.
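A hedged sketch of the multi-optimizer case (the two-part model and the losses are invented for illustration):
```python
import torch
from apex import amp

encoder = torch.nn.Linear(32, 32).cuda()
decoder = torch.nn.Linear(32, 8).cuda()
opt_enc = torch.optim.SGD(encoder.parameters(), lr=0.1)
opt_dec = torch.optim.SGD(decoder.parameters(), lr=0.1)

amp_handle = amp.init()
opt_enc = amp_handle.wrap_optimizer(opt_enc)
opt_dec = amp_handle.wrap_optimizer(opt_dec)

x = torch.randn(4, 32).cuda()
h = encoder(x)
loss_enc = (h ** 2).mean()            # gradients flow only into encoder params
loss_dec = decoder(h.detach()).sum()  # gradients flow only into decoder params

with opt_enc.scale_loss(loss_enc) as scaled_loss:
    scaled_loss.backward()
with opt_dec.scale_loss(loss_dec) as scaled_loss:
    scaled_loss.backward()

opt_enc.step()
opt_dec.step()
```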
#### Handling multiple backward passes
PyTorch accumulates parameter gradients between calls to
`zero_grad()`, so it is possible to perform multiple backward passes
before making a parameter update:
```python
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
loss1.backward()
# ...
loss2 = ComputeLoss2(model)
loss2.backward()
# ...
optimizer.step() # has gradient contributions from both backward passes
```
The amp optimizer wrapper supports an additional argument `num_loss`
to work with code like this:
```python
amp_handle = amp.init()
optimizer = amp_handle.wrap_optimizer(optimizer, num_loss=2)
# ...
optimizer.zero_grad()
loss1 = ComputeLoss1(model)
with optimizer.scale_loss(loss1) as scaled_loss:
scaled_loss.backward()
# ...
loss2 = ComputeLoss2(model)
with optimizer.scale_loss(loss2) as scaled_loss:
scaled_loss.backward()
# ...
optimizer.step()
```
## Annotating User Functions
Nearly all PyTorch user code needs nothing more than the two steps
...@@ -238,7 +61,7 @@ registration:
When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.initialize()`.
For our FRU unit, we can register the backend function directly:
...@@ -246,5 +69,4 @@ For our FRU unit, we can register the backend function directly:
import backend
amp.register_half_function(backend, 'FRUBackend')
```
...@@ -40,12 +40,12 @@ def applier(value, fn):
        return value
    elif isinstance(value, np.ndarray):
        return value
    elif isinstance(value, container_abcs.Mapping):
        return {applier(k, fn) : applier(v, fn) for k, v in value.items()}
    elif isinstance(value, container_abcs.Iterable):
        return type(value)(applier(v, fn) for v in value)
    elif hasattr(value, "to"): # Allow handling of custom batch classes
        return fn(value)
    else:
        # Do I want this to fire off even if someone chooses to pass something ordinary like
        # an int or float? May be more annoying than it's worth.
...@@ -89,7 +89,16 @@ def check_params_fp32(models):
                        "you chose. Use model.to('cuda') to use the default device.".format(
                        name, param.type()))

        # Backward compatibility for PyTorch 0.4
        if hasattr(model, 'named_buffers'):
            buf_iter = model.named_buffers()
        else:
            buf_iter = model._buffers
        for obj in buf_iter:
            if type(obj)==tuple:
                name, buf = obj
            else:
                name, buf = obj, buf_iter[obj]
            if buf.is_floating_point():
                if 'Half' in buf.type():
                    warn_or_err("Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
...@@ -201,7 +210,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
    _amp_state.loss_scalers = []
    for _ in range(num_losses):
        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale,
                                                  min_loss_scale=_amp_state.min_loss_scale,
                                                  max_loss_scale=_amp_state.max_loss_scale))

    if properties.patch_torch_functions:
        # handle is unused here. It's accessible later through a global value anyway.
......
...@@ -401,7 +401,9 @@ def _process_optimizer(optimizer, properties):
            _master_params_to_model_params, optimizer)

    old_step = optimizer.step
    def new_step(self, closure=None):
        if closure is not None:
            raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
        retval = old_step()
        if not (isinstance(self, FusedAdam) or isinstance(self, FusedSGD)):
            self._master_params_to_model_params()
...@@ -470,6 +472,11 @@ def _process_optimizer(optimizer, properties):
    def new_add_param_group(self, new_group):
        stash = self._amp_stash

        if not stash.lazy_init_called:
            self._lazy_init_maybe_master_weights()
            stash.lazy_init_called = True

        assert isinstance(new_group, dict), "param group must be a dict"
        new_params = new_group['params']
......
...@@ -195,7 +195,7 @@ def initialize(
    models,
    optimizers=None,
    enabled=True,
    opt_level="O1",
    cast_model_type=None,
    patch_torch_functions=None,
    keep_batchnorm_fp32=None,
...@@ -204,6 +204,8 @@ def initialize(
    cast_model_outputs=None,
    num_losses=1,
    verbosity=1,
    min_loss_scale=None,
    max_loss_scale=2.**24
    ):
    """
    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
...@@ -231,7 +233,7 @@ def initialize(
            REQUIRED for training, optional for inference.
        enabled (bool, optional, default=True): If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present.
        opt_level (str, optional, default="O1"): Pure or mixed precision optimization level. Accepted values are
            "O0", "O1", "O2", and "O3", explained in detail above.
        cast_model_type (``torch.dtype``, optional, default=None): Optional property override, see
            above.
...@@ -251,6 +253,11 @@ def initialize(
            support multiple losses/backward passes, but use a single global loss scale
            for all of them.
        verbosity (int, default=1): Set to 0 to suppress Amp-related output.
        min_loss_scale (float, default=None): Sets a floor for the loss scale values that can be chosen by dynamic
            loss scaling. The default value of None means that no floor is imposed.
            If dynamic loss scaling is not used, `min_loss_scale` is ignored.
        max_loss_scale (float, default=2.**24): Sets a ceiling for the loss scale values that can be chosen by
            dynamic loss scaling. If dynamic loss scaling is not used, `max_loss_scale` is ignored.
    Returns:
        Model(s) and optimizer(s) modified according to the ``opt_level``.
...@@ -301,7 +308,10 @@ def initialize(
    _amp_state.verbosity = verbosity

    if not enabled:
        if optimizers is None:
            return models
        else:
            return models, optimizers

    if not torch.backends.cudnn.enabled:
        raise RuntimeError(
...@@ -310,7 +320,8 @@ def initialize(
    if opt_level not in opt_levels:
        raise RuntimeError(
            "Unexpected optimization level {}. ".format(opt_level) +
            "Options are 'O0', 'O1', 'O2', 'O3'. Note that in `O0`, `O1`, etc., the prefix O is the letter O, " +
            "not the number zero.")
    else:
        _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
        maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
...@@ -318,6 +329,9 @@ def initialize(
    for k, v in _amp_state.opt_properties.options.items():
        maybe_print("{:22} : {}".format(k, v), True)

    _amp_state.min_loss_scale = min_loss_scale
    _amp_state.max_loss_scale = max_loss_scale

    maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
    # I chose to have the keyword arguments listed directly in the argument list,
    # instead of **kwargs, so I can't use kwargs.items() here.
......
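A hedged usage sketch of the two new keyword arguments shown above (the model, optimizer, and values are placeholders chosen for illustration):
```python
import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# O1 uses dynamic loss scaling by default; the new arguments clamp the scale
# chosen by the dynamic algorithm to [min_loss_scale, max_loss_scale].
model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                  min_loss_scale=1.0, max_loss_scale=2.**16)
```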
...@@ -15,7 +15,8 @@ def scale_loss(loss,
                optimizers,
                loss_id=0,
                model=None,
                delay_unscale=False,
                delay_overflow_check=False):
    """
    On context manager entrance, creates ``scaled_loss = (loss.float())*current loss scale``.
    ``scaled_loss`` is yielded so that the user can call ``scaled_loss.backward()``::
...@@ -120,7 +121,7 @@ def scale_loss(loss,
            optimizer._amp_stash.params_have_scaled_gradients = False

    # For future fused optimizers that enable sync-free dynamic loss scaling,
    # should_skip will always be False.
    should_skip = False if delay_overflow_check else loss_scaler.update_scale()
    if should_skip:
        for optimizer in optimizers:
            if not optimizer._amp_stash.already_patched:
...@@ -128,7 +129,9 @@ def scale_loss(loss,
                # necessary because amp.scale_loss is already creating a temporary scope.
                def patch_step(opt, loss_scaler, loss_id):
                    opt_step = opt.step
                    def skip_step(closure=None):
                        if closure is not None:
                            raise RuntimeError("Currently, Amp does not support closure use with optimizers.")
                        maybe_print(("Gradient overflow. Skipping step, loss scaler " +
                                     "{} reducing loss scale to {}").format(loss_id,
                                     loss_scaler.loss_scale()))
......
...@@ -28,8 +28,9 @@ FP16_FUNCS = [
FP32_FUNCS = [
    # Interpolation/Upsampling TODO: Remove for 1.2
    'interpolate',
    'grid_sample',

    # Pointwise
    'softplus',
......
...@@ -5,10 +5,10 @@ import importlib
import torch

# if compat.variable_is_tensor() and not compat.tensor_is_variable():
MODULE = torch.Tensor
# else:
#     MODULE = torch.autograd.Variable

FP16_FUNCS = [
......
...@@ -49,7 +49,7 @@ FP32_FUNCS = [
    'cumprod',
    'cumsum',
    'dist',
    # 'mean',
    'norm',
    'prod',
    'std',
...@@ -60,6 +60,14 @@ FP32_FUNCS = [
    'renorm'
]

version_strings = torch.__version__.split('.')
version_major = version_strings[0]
version_minor = version_strings[1]
version_num = float(version_major + "." + version_minor)
# Before torch 1.1, mean must be blacklisted.
if version_num < 1.1:
    FP32_FUNCS.append('mean')

# Before CUDA 9.1, batched matmul was missing fast FP16 kernels. We
# check the CUDA version -- if at least 9.1, then put the bmm
# functions on the fp16 list. Otherwise, put them on the fp32 list.
......
...@@ -39,14 +39,17 @@ class LossScaler(object):
                 loss_scale,
                 init_scale=2.**16,
                 scale_factor=2.,
                 scale_window=2000,
                 min_loss_scale=None,
                 max_loss_scale=2.**24):
        if loss_scale == "dynamic":
            self.dynamic = True
            self._loss_scale = init_scale
        else:
            self.dynamic = False
            self._loss_scale = loss_scale
        self._max_loss_scale = max_loss_scale
        self._min_loss_scale = min_loss_scale
        self._scale_seq_len = scale_window
        self._unskipped = 0
        self._has_overflow = False
...@@ -198,14 +201,17 @@ class LossScaler(object):
        if self._has_overflow and self.dynamic:
            should_skip = True
            if(self._min_loss_scale):
                self._loss_scale = max(self._min_loss_scale, self._loss_scale/2.)
            else:
                self._loss_scale = self._loss_scale/2.
            self._unskipped = 0
        else:
            should_skip = False
            self._unskipped += 1

        if self._unskipped == self._scale_seq_len and self.dynamic:
            self._loss_scale = min(self._max_loss_scale, self._loss_scale*2.)
            self._unskipped = 0

        return should_skip
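For readers skimming the diff, the dynamic loss-scaling policy implemented by `LossScaler.update_scale()` above can be summarized with a small standalone sketch (not apex code): halve the scale on overflow, never dropping below `min_loss_scale` if one is set, and double it after `scale_window` consecutive non-overflowing steps, capped at `max_loss_scale`.
```python
def next_loss_scale(scale, overflowed, unskipped, scale_window=2000,
                    min_loss_scale=None, max_loss_scale=2.**24):
    """Toy restatement of the dynamic branch of LossScaler.update_scale()."""
    if overflowed:
        scale = scale / 2.
        if min_loss_scale is not None:
            scale = max(min_loss_scale, scale)
        return scale, 0          # the step is skipped and the counter resets
    unskipped += 1
    if unskipped == scale_window:
        scale = min(max_loss_scale, scale * 2.)
        unskipped = 0
    return scale, unskipped
```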
...@@ -206,10 +206,6 @@ class DistributedDataParallel(Module):
        if shared_param is not None:
            raise ValueError("shared_param is no longer supported as an option. It was misleadingly named from the start. It turns out overlapping communication with computation should work fine with shared parameters. If you still wish to delay communication to the end of the backward pass, use delay_allreduce=True|False instead.")

        self.world_size = float(dist.get_world_size())

        self.retain_allreduce_buffers = retain_allreduce_buffers
...@@ -234,6 +230,8 @@ class DistributedDataParallel(Module):
        self.module = module

        self._disable_allreduce = False

        if self._backend == self.backend_enum_holder.NCCL:
            for param in self.module.parameters():
                assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU."
...@@ -274,8 +272,14 @@ class DistributedDataParallel(Module):
        del attrs['self.bucket_events']
        return attrs

    def enable_allreduce(self):
        self._disable_allreduce = False

    def disable_allreduce(self):
        self._disable_allreduce = True

    # Broadcast rank 0's bucket structure across all processes, and have all processes
    # regenerate their bucket structures to match.
    def sync_bucket_structure(self):
        # Append leftover buckets
        for tmp_bucket in self.tmp_buckets:
...@@ -352,52 +356,53 @@ class DistributedDataParallel(Module):
            grad_acc = param_tmp.grad_fn.next_functions[0][0]

            def allreduce_hook(*unused):
                if self.prof:
                    torch.cuda.nvtx.range_push("allreduce_hook")

                if not self._disable_allreduce:
                    if self.delay_allreduce or self.needs_refresh:
                        # TODO: How do we want to handle multiple backward passes between
                        # each forward, e.g., backward passes with retain_graph=True?
                        # needs_refresh and callback_queued are both vulnerable states.
                        if not self.delay_allreduce and self.needs_refresh:
                            # Use the backward pass to build the bucket structure on the fly.
                            active_i = self.param_id_to_active_i[id(param)]

                            # Float, half, and double tensors are grouped into buckets separately.
                            current_type = self.param_type_to_tmp_i[param.type()]

                            self.tmp_buckets[current_type].append(active_i)

                            ship_tmp_bucket = False
                            if self.custom_allreduce_triggers:
                                if id(param) in self.allreduce_trigger_params:
                                    ship_tmp_bucket = True
                            else:
                                self.tmp_numels[current_type] += param.numel()
                                if self.tmp_numels[current_type] >= self.message_size:
                                    ship_tmp_bucket = True

                            # To consider: If custom_allreduce_triggers are in use, ship all
                            # tmp_buckets, not just tmp_buckets[current_type].
                            if ship_tmp_bucket:
                                self.active_i_buckets.append(self.tmp_buckets[current_type])
                                self.tmp_buckets[current_type] = []
                                self.tmp_numels[current_type] = 0

                        if not self.callback_queued:
                            Variable._execution_engine.queue_callback(allreduce_params)
                            self.callback_queued = True
                    else:
                        if not self.callback_queued:
                            Variable._execution_engine.queue_callback(overlapping_backward_epilogue)
                            self.callback_queued = True

                        self.comm_ready_buckets(param)

                if self.prof:
                    torch.cuda.nvtx.range_pop()

            grad_acc.register_hook(allreduce_hook)
            self.grad_accs.append(grad_acc)
...@@ -468,8 +473,6 @@ class DistributedDataParallel(Module):
            # further reuse by the main stream while the allreduce/div/unflatten are underway in bucket_stream.
            tensor.record_stream(bucket_stream)

        return tensor
...@@ -560,75 +563,76 @@ class DistributedDataParallel(Module):
        if self.prof:
            torch.cuda.nvtx.range_push("forward pass DDP logic")

        if not self._disable_allreduce:
            if not self.delay_allreduce:
                param_list = [param for param in self.module.parameters() if param.requires_grad]

                # Conditions under which to refresh self.record
                # Forward has the authority to set needs_refresh to True, but only allreduce_params
                # in backward has the authority to set needs_refresh to False.
                # Parentheses are not necessary for correct order of operations, but make the intent clearer.
                if ((not self.active_params) or
                    (len(param_list) != len(self.active_params)) or
                    any([param1 is not param2 for param1, param2 in zip(param_list, self.active_params)])):
                    self.needs_refresh = True

                if self.needs_refresh:
                    self.active_i_buckets = []
                    self.buckets = []
                    self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
                    self.tmp_numels = [0, 0, 0]
                    self.bucket_sizes = []
                    self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
                    self.param_id_to_bucket = {}
                    self.bucket_pgs = []
                    self.bucket_streams = []
                    self.bucket_events = []
                else:
                    # self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                    #                 for i in range(self.num_buckets)]
                    if not self.buckets:
                        self.buckets = [[None for _ in range(self.bucket_sizes[i])]
                                        for i in range(self.num_buckets)]
                    else:
                        assert len(self.buckets) == self.num_buckets, "len(buckets) = {}, expected {}".format(
                            len(self.buckets), self.num_buckets)
                        for b, bucket in enumerate(self.buckets):
                            assert len(bucket) == self.bucket_sizes[b], "len(buckets[{}]) = {}, expected {})".format(
                                b, len(buckets[b]), self.bucket_sizes[b])
                            for i in range(len(bucket)):
                                bucket[i] = None

                    if self.allreduce_communicators:
                        self.bucket_pgs = self.allreduce_communicators[0]
                        self.bucket_streams = self.allreduce_communicators[1]
                        self.bucket_events = [torch.cuda.Event(enable_timing=False,
                                                               blocking=False) for _ in range(self.num_allreduce_streams)]
                    else:
                        if self.allreduce_different_streams:
                            if not self.bucket_pgs:
                                self.bucket_pgs = [dist.new_group() for _ in range(self.num_allreduce_streams)]
                                for i, bg in enumerate(self.bucket_pgs):
                                    print("rank {} created group {} with backend {}".format(
                                          dist.get_rank(), i, dist.get_backend(bg)))
                        if self.allreduce_different_streams:
                            if not self.bucket_streams:
                                self.bucket_streams = [torch.cuda.Stream() for _ in range(self.num_allreduce_streams)]
                                self.bucket_events = [torch.cuda.Event(enable_timing=False,
                                                                       blocking=False) for _ in range(self.num_allreduce_streams)]
                        else:
                            if not self.bucket_streams:
                                self.bucket_streams = [torch.cuda.Stream()]
                                self.bucket_events = [torch.cuda.Event(enable_timing=False, blocking=False)]

                self.buckets_ready_size = [0 for i in range(self.num_buckets)]
                if(self.retain_allreduce_buffers):
                    self.allreduce_buffers = [None for _ in range(self.num_buckets)]
                self.next_bucket = 0
                self.ready_buckets_not_reduced = set()

                self.active_params = param_list

            self.callback_queued = False

        if self.prof:
            torch.cuda.nvtx.range_pop()
......
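The new `enable_allreduce`/`disable_allreduce` toggles suggest a gradient-accumulation pattern along the lines of the sketch below. This is an assumption based on the methods added in this diff, not documented usage, and whether the resulting gradient averaging matches your accumulation semantics should be verified (`model`, `loader`, `criterion`, and `optimizer` are placeholders):
```python
from apex.parallel import DistributedDataParallel as DDP

model = DDP(model)   # model, loader, criterion, optimizer defined elsewhere
accum_steps = 4

for i, (inp, target) in enumerate(loader):
    # Skip the inter-process allreduce on intermediate accumulation steps;
    # gradients keep accumulating locally in param.grad.
    if (i + 1) % accum_steps == 0:
        model.enable_allreduce()
    else:
        model.disable_allreduce()

    loss = criterion(model(inp), target)
    loss.backward()

    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```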
...@@ -27,10 +27,31 @@ void multi_tensor_axpby_cuda(
  float b,
  int arg_to_check);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

void multi_tensor_lamb_stage1_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_decay,
  const int step,
  const float beta1,
  const float beta2,
  const float epsilon,
  const float global_grad_norm,
  const float max_global_grad_norm);

void multi_tensor_lamb_stage2_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float step_size);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
...@@ -41,4 +62,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "out = a*x + b*y for a list of contiguous tensors");
  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors");
  m.def("multi_tensor_lamb_stage1_cuda", &multi_tensor_lamb_stage1_cuda,
        "Computes update part of LAMB optimizer");
  m.def("multi_tensor_lamb_stage2_cuda", &multi_tensor_lamb_stage2_cuda,
        "Completes application of gradient to parameters for LAMB optimizer");
}
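As a hedged illustration of how the extended `multi_tensor_l2norm` binding might be driven from Python: apex ships these kernels in the `amp_C` extension and invokes them through `apex.multi_tensor_apply.multi_tensor_applier`, so a call could look roughly like the sketch below (treat the exact argument order as an assumption read off this diff rather than documented API):
```python
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

grads = [torch.randn(1024, device="cuda", dtype=torch.float16) for _ in range(8)]
overflow_buf = torch.zeros(1, device="cuda", dtype=torch.int)

# With per_tensor_python=True the binding now returns a tuple:
# (global L2 norm over all tensors, per-tensor L2 norms).
global_norm, per_tensor_norms = multi_tensor_applier(
    amp_C.multi_tensor_l2norm,
    overflow_buf,
    [grads],
    True)
```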
...@@ -20,6 +20,7 @@ template<int n> struct TensorListMetadata
  int sizes[depth_to_max_tensors[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]]; // I fear this needs to be a full int.
  int start_tensor_this_launch;
};
...@@ -66,6 +67,7 @@ void multi_tensor_apply(
  auto stream = at::cuda::getCurrentCUDAStream();

  tl.start_tensor_this_launch = 0;
  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for(int t = 0; t < ntensors; t++)
...@@ -106,6 +108,7 @@ void multi_tensor_apply(
        {
          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          loc_tensor_info = 0;
          tl.start_tensor_this_launch = t + 1;
        }
        else
        {
...@@ -114,6 +117,7 @@ void multi_tensor_apply(
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          loc_tensor_info = 1;
          tl.start_tensor_this_launch = t;
        }
      }
    }
......
...@@ -16,11 +16,14 @@
template<typename x_t>
struct L2NormFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<1>& tl,
    float* output,
    float* output_per_tensor,
    bool per_tensor,
    int max_chunks_per_tensor)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
...@@ -35,47 +38,114 @@ struct L2NormFunctor
    n -= chunk_idx*chunk_size;

    __shared__ float s_vals[512];

    float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
    for(int i = 0; i < ILP; i++)
      vals[i] = 0.f;

    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
    {
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          float next = static_cast<float>(x[i]);
          vals[ii] += next*next;
        }
      }
    }

    float val = 0.f;
    for(int i = 0; i < ILP; i++)
      val += vals[i];

    float final = reduce_block_into_lanes(s_vals, val);

    if(threadIdx.x == 0)
    {
      if(!isfinite(final))
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
      output[blockIdx.x] += final;
      if(per_tensor)
        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
    }
  }
};

__global__ void cleanup(
  float* output,
  float* output_per_tensor,
  float* ret,
  float* ret_per_tensor,
  bool per_tensor,
  int max_chunks_per_tensor)
{
  __shared__ float vals[512];

  if(blockIdx.x == 0)
  {
    float val = 0;
    if(threadIdx.x < 320)
      val = output[threadIdx.x];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      *ret = sqrt(final);
  }

  if(per_tensor)
  {
    float* output_this_tensor = output_per_tensor + blockIdx.x*max_chunks_per_tensor;

    float val = 0;
    for(int i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x)
      val += output_this_tensor[i];

    float final = reduce_block_into_lanes(vals, val);

    if(threadIdx.x == 0)
      ret_per_tensor[blockIdx.x] = sqrt(final);
  }
}

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python)
{
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if(per_tensor)
  {
    for(int t = 0; t < ntensors; t++)
    {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
      if(max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  }
  else
  {
    ret_per_tensor = at::empty({0}, float_options);
  }

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
    multi_tensor_apply<1>(
...@@ -84,7 +154,10 @@ at::Tensor multi_tensor_l2norm_cuda(
      noop_flag,
      tensor_lists,
      L2NormFunctor<scalar_t_0>(),
      output.data<float>(),
      per_tensor ? output_per_tensor.data<float>() : nullptr,
      per_tensor,
      max_chunks_per_tensor);)

  AT_CUDA_CHECK(cudaGetLastError());
...@@ -95,6 +168,13 @@ at::Tensor multi_tensor_l2norm_cuda(
  // logic, but keeping it simple for now
  auto ret = at::empty({1}, output.options());
  auto stream = at::cuda::getCurrentCUDAStream();
  cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
    output.data<float>(),
    per_tensor ? output_per_tensor.data<float>() : nullptr,
    ret.data<float>(),
    per_tensor ? ret_per_tensor.data<float>() : nullptr,
    per_tensor,
    max_chunks_per_tensor);

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
// Step 1 computes the 'update' value of regular Adam optimizer.
template<typename GRAD_T, typename T>
struct LAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const float* per_tensor_decay,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float clipped_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float decay = per_tensor_decay[tensor_num];
GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
T* update = (T*)tl.addresses[4][tensor_loc];
update += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
GRAD_T r_g[ILP];
T r_p[ILP];
T r_m[ILP];
T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = GRAD_T(0);
r_p[ii] = T(0);
r_m[ii] = T(0);
r_v[ii] = T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + (1-beta1) * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
T next_m_unbiased = r_m[ii] / beta1_correction;
T next_v_unbiased = r_v[ii] / beta2_correction;
T denom = std::sqrt(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
update[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_lamb_stage1_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_decay,
const int step,
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage1Functor<scalar_t_0, scalar_t_1>(),
per_tensor_decay.data<float>(),
beta1,
beta2,
beta1_correction,
beta2_correction,
epsilon,
clipped_global_grad_norm); ))
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T>
struct LAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<2>& tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
T* p = (T*)tl.addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
T* update = (T*)tl.addresses[1][tensor_loc];
update += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
T r_p[ILP];
T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio*r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
}
}
}
}
};
void multi_tensor_lamb_stage2_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
multi_tensor_apply<2>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
LAMBStage2Functor<scalar_t_0>(),
per_tensor_param_norm.data<float>(),
per_tensor_update_norm.data<float>(),
learning_rate); )
AT_CUDA_CHECK(cudaGetLastError());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
}
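Taken together, the two new kernels implement the LAMB update in two fused passes. As a reading aid (not the fused implementation), the per-parameter math corresponds roughly to the scalar sketch below; in the real kernels the norms used in stage 2 are per-tensor L2 norms obtained from `multi_tensor_l2norm`:
```python
import math

def lamb_reference_step(p, g, m, v, step, lr, beta1, beta2, eps, decay,
                        clipped_global_grad_norm):
    """Scalar restatement of multi_tensor_lamb_stage{1,2}_cuda (single element)."""
    # Stage 1: Adam-style moments on the globally clipped gradient, plus
    # decoupled weight decay, producing the 'update' value.
    g = g / clipped_global_grad_norm
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** (step + 1))
    v_hat = v / (1 - beta2 ** (step + 1))
    update = m_hat / (math.sqrt(v_hat) + eps) + decay * p

    # Stage 2: scale by the trust ratio lr * ||p|| / ||update|| and apply.
    # (abs() stands in for the per-tensor L2 norms of the real kernel.)
    param_norm, update_norm = abs(p), abs(update)
    ratio = lr * (param_norm / update_norm) if param_norm != 0.0 and update_norm != 0.0 else lr
    return p - ratio * update, m, v
```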
...@@ -94,7 +94,7 @@ receive gradients.
If, for a given backward pass, there's only one optimizer whose params are about to receive gradients,
you may pass that optimizer directly to ``amp.scale_loss``. Otherwise, you must pass the
list of optimizers whose params are about to receive gradients. Example with 3 losses and 2 optimizers::

    # loss0 accumulates gradients only into params owned by optim0:
    with amp.scale_loss(loss0, optim0) as scaled_loss:
...@@ -145,18 +145,20 @@ Gradient accumulation across iterations
The following should "just work," and properly accommodate multiple models/optimizers/losses, as well as
gradient clipping via the `instructions above`_::

    # If your intent is to simulate a larger batch size using gradient accumulation,
    # you can divide the loss by the number of accumulation iterations (so that gradients
    # will be averaged over that many iterations):
    loss = loss/iters_to_accumulate

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()

    # Every iters_to_accumulate iterations, call step() and reset gradients:
    if iter%iters_to_accumulate == 0:
        # Gradient clipping if desired:
        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
        optimizer.step()
        optimizer.zero_grad()

As a minor performance optimization, you can pass ``delay_unscale=True``
to ``amp.scale_loss`` until you're ready to ``step()``. You should only attempt ``delay_unscale=True``
......
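As a hedged illustration of the ``delay_unscale`` note above (an example written for this cleanup, not an excerpt from the docs), the accumulation loop can postpone unscaling until the iteration that actually steps::

    loss = loss/iters_to_accumulate
    step_iteration = (iter%iters_to_accumulate == 0)
    # delay_unscale=True skips gradient unscaling on non-step iterations.
    with amp.scale_loss(loss, optimizer, delay_unscale=not step_iteration) as scaled_loss:
        scaled_loss.backward()
    if step_iteration:
        optimizer.step()
        optimizer.zero_grad()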
...@@ -173,3 +173,9 @@ Running with the `--deterministic` flag should produce bitwise identical outputs
regardless of what other options are used (see [PyTorch docs on reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)).
Since `--deterministic` disables `torch.backends.cudnn.benchmark`, `--deterministic` may
cause a modest performance decrease.

## Profiling
If you're curious what the network actually looks like on the CPU and GPU timelines (for example, how good is the overall
utilization? Is the prefetcher really overlapping data transfers?), try profiling `main_amp.py`.
[Detailed instructions can be found here](https://gist.github.com/mcarilli/213a4e698e4a0ae2234ddee56f4f3f95).
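Note that `main_amp.py` now brackets the profiled window with `torch.cuda.cudart().cudaProfilerStart()`/`cudaProfilerStop()` around the iterations selected by `--prof` (see the diff below), so an external profiler should generally be launched with profiling initially disabled (for example, nvprof's `--profile-from-start off` option) so that only that window is captured; the linked gist walks through the details.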
...@@ -60,7 +60,7 @@ parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--prof', default=-1, type=int,
                    help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
...@@ -236,8 +236,7 @@ def main():
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)
...@@ -272,9 +271,23 @@ class data_prefetcher():
            self.next_input = None
            self.next_target = None
            return
        # if record_stream() doesn't work, another option is to make sure device inputs are created
        # on the main stream.
        # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda')
        # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda')
        # Need to make sure the memory allocated for next_* is not still in use by the main stream
        # at the time we start copying to next_*:
        # self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # more code for the alternative if record_stream() doesn't work:
            # copy_ will record the use of the pinned source tensor in this side stream.
            # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
            # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
            # self.next_input = self.next_input_gpu
            # self.next_target = self.next_target_gpu
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
...@@ -286,6 +299,10 @@ class data_prefetcher():
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        if input is not None:
            input.record_stream(torch.cuda.current_stream())
        if target is not None:
            target.record_stream(torch.cuda.current_stream())
        self.preload()
        return input, target
...@@ -305,33 +322,34 @@ def train(train_loader, model, criterion, optimizer, epoch):
    i = 0
    while input is not None:
        i += 1
        if args.prof >= 0 and i == args.prof:
            print("Profiling begun at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStart()

        if args.prof >= 0: torch.cuda.nvtx.range_push("Body of iteration {}".format(i))

        adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        # compute output
        if args.prof >= 0: torch.cuda.nvtx.range_push("forward")
        output = model(input)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()

        if args.prof >= 0: torch.cuda.nvtx.range_push("backward")
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        if args.prof >= 0: torch.cuda.nvtx.range_push("optimizer.step()")
        optimizer.step()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if i%args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
...@@ -370,8 +388,17 @@ def train(train_loader, model, criterion, optimizer, epoch):
                args.world_size*args.batch_size/batch_time.avg,
                batch_time=batch_time,
                loss=losses, top1=top1, top5=top5))

        if args.prof >= 0: torch.cuda.nvtx.range_push("prefetcher.next()")
        input, target = prefetcher.next()
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        # Pop range "Body of iteration {}".format(i)
        if args.prof >= 0: torch.cuda.nvtx.range_pop()

        if args.prof >= 0 and i == args.prof + 10:
            print("Profiling ended at iteration {}".format(i))
            torch.cuda.cudart().cudaProfilerStop()
            quit()

def validate(val_loader, model, criterion):
......