Commit 6af5980e authored by Michael Carilli

Merging in FusedAdam treatment

parents 16a3bdf3 7aad54f7
# Introduction
This repository holds NVIDIA-maintained utilities to streamline
mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intention of Apex is to make up-to-date utilities available to
users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
......@@ -29,7 +29,7 @@ different flags to `amp.initialize`.
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
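
As a quick orientation, here is a minimal, hedged sketch of how the two pieces above are typically combined in one training script; the model, hyperparameters, and launcher setup are illustrative, not from this commit:

```python
import torch
from apex import amp
from apex.parallel import DistributedDataParallel

# Assumes torch.distributed.init_process_group(backend="nccl") has already
# been called by the launcher (e.g. torch.distributed.launch).
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# "O1" patches torch functions for mixed precision; "O2" casts the model
# to FP16 and maintains FP32 master weights.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
model = DistributedDataParallel(model)
```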
......
......@@ -114,29 +114,13 @@ def check_optimizers(optimizers):
            raise RuntimeError("An incoming optimizer is an instance of {}. ".format(optim_type) +
                               "The optimizer(s) passed to amp.initialize() must be bare \n"
                               "instances of either ordinary Pytorch optimizers, or Apex fused \n"
                               "optimizers (FusedAdam or FusedSGD). \n"
                               "You should not manually wrap your optimizer in either \n"
                               "apex.fp16_utils.FP16_Optimizer or apex.optimizers.FP16_Optimizer. \n"
                               "amp.initialize will take care of that for you (if necessary) based \n"
                               "on the specified opt_level (and optional overridden properties).")
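
# Hedged, standalone illustration (not part of this diff): what the check
# above accepts and rejects. Pass bare optimizer instances to
# amp.initialize; it performs any wrapping itself.
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # bare optimizer: accepted
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

# Pre-wrapping trips the RuntimeError above:
# wrapped = apex.fp16_utils.FP16_Optimizer(optimizer)  # don't do this
# amp.initialize(model, wrapped, opt_level="O2")       # -> RuntimeError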
def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs=None):
    from apex.parallel import DistributedDataParallel as apex_DDP
    from .amp import init as amp_init
......@@ -163,7 +147,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
    if not _amp_state.allow_incoming_model_not_fp32:
        check_params_fp32(models)

    check_optimizers(optimizers)

    # In the future, when FP16_Optimizer can be deprecated and master weights can
......@@ -196,7 +180,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
            model.forward = patch_forward(model.forward)

        # State dict trick to recast any preexisting per-param state tensors
        for optimizer in optimizers:
            optimizer.load_state_dict(optimizer.state_dict())

    elif cast_model_outputs is not None:
......@@ -212,11 +196,7 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
            model.forward = patch_forward(model.forward)

    for i, optimizer in enumerate(optimizers):
        optimizers[i] = _process_optimizer(optimizer, properties)

    _amp_state.loss_scalers = []
    for _ in range(num_losses):
......
......@@ -3,6 +3,7 @@ from ..fp16_utils import master_params_to_model_params
from ..multi_tensor_apply import multi_tensor_applier
from ._amp_state import maybe_print
import torch
from ..optimizers import FusedAdam
class AmpOptimizerState(object):
......@@ -73,6 +74,40 @@ def lazy_init_with_master_weights(self):
    self.load_state_dict(self.state_dict())
def post_backward_models_are_masters(scaler, params, stashed_grads):
    # This is a lot of python overhead...
    grads_needing_unscale = []
    grads_needing_unscale_with_stash = []
    stashed = []
    for param, stashed_grad in zip(params, stashed_grads):
        if param.grad is None and stashed_grad is not None:
            param.grad = stashed_grad
        elif param.grad is not None and stashed_grad is None:
            grads_needing_unscale.append(param.grad)
        elif param.grad is not None and stashed_grad is not None:
            grads_needing_unscale_with_stash.append(param.grad)
            stashed.append(stashed_grad)
        else:  # param.grad is None and stashed_grad is None
            continue

    if len(grads_needing_unscale) > 0:
        scaler.unscale(
            grads_needing_unscale,
            grads_needing_unscale,
            scaler.loss_scale(),
            models_are_masters=True)

    if len(grads_needing_unscale_with_stash) > 0:
        scaler.unscale_with_stashed(
            grads_needing_unscale_with_stash,
            stashed,
            grads_needing_unscale_with_stash)

    # Clear the stash.
    for i in range(len(stashed_grads)):
        stashed_grads[i] = None
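
# Hedged, standalone demonstration of the helper above using a stub scaler
# (the real LossScaler lives in apex.amp.scaler): a grad with no stashed
# counterpart is unscaled in place; a grad with a stash is unscaled and
# accumulated into the stashed gradient via unscale_with_stashed.
import torch

class _StubScaler(object):
    def loss_scale(self):
        return 128.0
    def unscale(self, model_grads, master_grads, scale, models_are_masters=False):
        for g in master_grads:
            g.div_(scale)  # in-place unscale
    def unscale_with_stashed(self, model_grads, stashed, master_grads):
        for g, s in zip(master_grads, stashed):
            g.div_(self.loss_scale()).add_(s)  # unscale, then add stashed grad

_p = torch.nn.Parameter(torch.zeros(3))
_p.grad = torch.full((3,), 256.0)
post_backward_models_are_masters(_StubScaler(), [_p], [None])
print(_p.grad)  # tensor([2., 2., 2.])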
def prepare_backward_with_master_weights(self):
    stash = self._amp_stash
......@@ -106,7 +141,7 @@ def post_backward_with_master_weights(self, scaler):
        if fp16_param.grad is None and fp32_param.grad is not None:
            continue
        elif fp16_param.grad is not None and fp32_param.grad is None:
            fp32_param.grad = torch.empty_like(fp32_param)
            fp16_grads_needing_unscale.append(fp16_param.grad)
            new_fp32_grads.append(fp32_param.grad)
        elif fp16_param.grad is not None and fp32_param.grad is not None:
......@@ -129,37 +164,10 @@ def post_backward_with_master_weights(self, scaler):
                preexisting_fp32_grads)

    # fp32 params can be treated as they would be in the "no_master_weights" case.
    post_backward_models_are_masters(
        scaler,
        stash.all_fp32_from_fp32_params,
        stash.all_fp32_from_fp32_grad_stash)
def lazy_init_no_master_weights(self):
......@@ -176,7 +184,7 @@ def lazy_init_no_master_weights(self):
            raise TypeError("Optimizer's parameters must be either "
                            "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                            "Received {}".format(param.type()))

    stash.all_fp16_grad_stash = [None for _ in stash.all_fp16_params]
    stash.all_fp32_grad_stash = [None for _ in stash.all_fp32_params]
......@@ -206,37 +214,56 @@ def post_backward_no_master_weights(self, scaler):
                   (stash.all_fp32_params, stash.all_fp32_grad_stash))

    for params, stashed_grads in split_types:
        post_backward_models_are_masters(scaler, params, stashed_grads)


def prepare_backward_with_master_weights_fused(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True
def post_backward_with_master_weights_fused(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = [[param.grad.data for param in group] for group in stash.fp16_groups]
    stash.output_params = [[param for param in group] for group in stash.fp16_groups]

    norm_groups = []
    skip = False
    for grad_group in stash.grads:
        norm = multi_tensor_applier(
            stash.multi_tensor_l2norm,
            stash.dummy_overflow_buf,
            [grad_group])
        # Still syncing here for now.
        norm = float(norm)
        norm_groups.append(norm)
        if norm == float('inf') or norm == -float('inf') or norm != norm:
            skip = True
    if skip:
        scaler._overflow_buf.fill_(1.)
        scaler._has_overflow = True
    stash.grad_norms = norm_groups


def prepare_backward_no_master_weights_fused(self):
    stash = self._amp_stash
    if not stash.lazy_init_called:
        self._lazy_init_maybe_master_weights()
        stash.lazy_init_called = True


def post_backward_no_master_weights_fused(self, scaler):
    stash = self._amp_stash
    stash.scale = scaler.loss_scale()
    stash.grads = None
    stash.output_params = None
    stash.grad_norms = None
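
# Hedged sketch of the overflow test used in the fused path above: a NaN
# or inf L2 norm in any gradient group flags the whole step as overflowed
# ("norm != norm" is the NaN test).
import math

def _should_skip(norm_groups):
    return any(math.isinf(n) or n != n for n in norm_groups)

print(_should_skip([1.5, 0.3]))           # False
print(_should_skip([1.5, float('nan')]))  # True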
def _master_params_to_model_params(self):
......@@ -274,6 +301,7 @@ def _process_optimizer(optimizer, properties):
    if multi_tensor_applier.available:
        import amp_C
        optimizer._amp_stash.multi_tensor_scale = amp_C.multi_tensor_scale
        optimizer._amp_stash.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
        optimizer._amp_stash.dummy_overflow_buf = torch.cuda.IntTensor([0])

    if properties.master_weights:
......@@ -286,7 +314,8 @@ def _process_optimizer(optimizer, properties):
        old_step = optimizer.step
        def new_step(self):
            retval = old_step()
            if not isinstance(self, FusedAdam):
                self._master_params_to_model_params()
            # Clear the master grads that wouldn't be zeroed by model.zero_grad()
            for param in self._amp_stash.all_fp32_from_fp16_params:
                param.grad = None
......@@ -313,19 +342,29 @@ def _process_optimizer(optimizer, properties):
                param.grad = None
        optimizer.zero_grad = types.MethodType(new_zero_grad, optimizer)

        if isinstance(optimizer, FusedAdam):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights_fused, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights_fused, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_with_master_weights, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_with_master_weights, optimizer)
    else:
        optimizer._lazy_init_maybe_master_weights = types.MethodType(
            lazy_init_no_master_weights, optimizer)

        if isinstance(optimizer, FusedAdam):
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights_fused, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights_fused, optimizer)
        else:
            optimizer._prepare_amp_backward = types.MethodType(
                prepare_backward_no_master_weights, optimizer)
            optimizer._post_amp_backward = types.MethodType(
                post_backward_no_master_weights, optimizer)

    return optimizer
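
# Hedged, standalone sketch (not apex code) of the monkey-patching pattern
# _process_optimizer relies on: types.MethodType binds a replacement method
# onto one optimizer instance, closing over the original bound method.
import types
import torch

_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

_old_step = _opt.step
def _new_step(self, closure=None):
    retval = _old_step(closure)  # delegate to the original step
    # ... post-step bookkeeping would go here, as in new_step above ...
    return retval

_opt.step = types.MethodType(_new_step, _opt)  # affects only this instance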
......@@ -6,8 +6,6 @@ from . import utils
from .opt import OptimWrapper
from .scaler import LossScaler
from ._amp_state import _amp_state, master_params, maybe_print
# There's no reason to expose the notion of a "handle". Everything can happen through amp.* calls.
......@@ -82,13 +80,8 @@ def scale_loss(loss,
    if isinstance(optimizers, torch.optim.Optimizer):
        optimizers = [optimizers]

    loss_scaler = _amp_state.loss_scalers[loss_id]
    loss_scale = loss_scaler.loss_scale()
    if ((not _amp_state.opt_properties.master_weights)
        and (not loss_scaler.dynamic)
......@@ -113,8 +106,8 @@ def scale_loss(loss,
        for optimizer in optimizers:
            optimizer._amp_stash.params_have_scaled_gradients = True
    else:
        # FusedAdam and FusedSGD may take care of unscaling as part of their step() methods.
        loss_scaler.clear_overflow_state()
        for optimizer in optimizers:
            optimizer._post_amp_backward(loss_scaler)
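
# Hedged usage sketch: the training-loop pattern that drives the
# scale_loss bookkeeping above (model and hyperparameters illustrative).
import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 8, device="cuda")).float().mean()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # grads are produced at the scaled magnitude
optimizer.step()            # hooks unscale before/inside step()
optimizer.zero_grad()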
......
......@@ -2,6 +2,8 @@ import types
import torch
import importlib
from ..multi_tensor_apply import multi_tensor_applier
class FusedAdam(torch.optim.Optimizer):
"""Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via
......@@ -25,6 +27,8 @@ class FusedAdam(torch.optim.Optimizer):
adds eps to the bias-corrected second moment estimate before
evaluating square root instead of adding it to the square root of
second moment estimate as in the original paper. (default: False)
use_mt (boolean, optional): use multi tensor apply for lower launch
latency. (default: False)
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
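
    Example (hedged sketch; values illustrative)::

        import torch
        from apex.optimizers import FusedAdam
        # use_mt opts into the multi-tensor kernel; it falls back to the
        # single-tensor path if the multi_tensor_applier extension is missing.
        params = [torch.nn.Parameter(torch.zeros(1024, device="cuda"))]
        optimizer = FusedAdam(params, lr=1e-3, use_mt=True)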
......@@ -35,10 +39,21 @@ class FusedAdam(torch.optim.Optimizer):
    def __init__(self, params,
                 lr=1e-3, bias_correction = True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt = False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                 amp_scale_adjustment=1.0):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")

        self._use_multi_tensor = False
        if use_mt:
            if not multi_tensor_applier.available:
                print("Warning: multi_tensor_applier is unavailable")
            else:
                self._use_multi_tensor = True
                self._overflow_buf = torch.cuda.IntTensor([0])

        self._amp_scale_adjustment = amp_scale_adjustment

        if amsgrad:
            raise RuntimeError('FusedAdam does not support the AMSGrad variant.')
        defaults = dict(lr=lr, bias_correction=bias_correction,
......@@ -66,6 +81,12 @@ class FusedAdam(torch.optim.Optimizer):
        if closure is not None:
            loss = closure()

        if hasattr(self, "_amp_stash"):
            grads = self._amp_stash.grads
            output_params = self._amp_stash.output_params
            scale = self._amp_stash.scale * self._amp_scale_adjustment
            grad_norms = self._amp_stash.grad_norms

        if grads is None:
            grads_group = [None]*len(self.param_groups)
        # backward compatibility
......@@ -105,6 +126,12 @@ class FusedAdam(torch.optim.Optimizer):
            bias_correction = 1 if group['bias_correction'] else 0

            if self._use_multi_tensor:
                if output_params:
                    tensorlists = [[], [], [], [], []]
                else:
                    tensorlists = [[], [], [], []]

            for p, grad, output_param in zip(group['params'], grads_this_group, output_params_this_group):
                # note: p.grad should not ever be set for correct operation of mixed precision optimizer that sometimes sends None gradients
                if p.grad is None and grad is None:
......@@ -130,18 +157,43 @@ class FusedAdam(torch.optim.Optimizer):
                state['step'] += 1

                out_p = torch.tensor([], dtype=torch.float) if output_param is None else output_param

                if self._use_multi_tensor:
                    pl = [p.data, exp_avg, exp_avg_sq, grad]
                    if output_param is not None:
                        pl.append(out_p)

                    for tl, t in zip(tensorlists, pl):
                        tl.append(t)
                else:
                    fused_adam_cuda.adam(p.data,
                                         out_p,
                                         exp_avg,
                                         exp_avg_sq,
                                         grad,
                                         group['lr'],
                                         beta1,
                                         beta2,
                                         group['eps'],
                                         combined_scale,
                                         state['step'],
                                         self.eps_mode,
                                         bias_correction,
                                         group['weight_decay'])

            if self._use_multi_tensor:
                multi_tensor_applier(
                    fused_adam_cuda.adam_mt,
                    self._overflow_buf,
                    tensorlists,
                    group['lr'],
                    beta1,
                    beta2,
                    group['eps'],
                    combined_scale,
                    state['step'],
                    self.eps_mode,
                    bias_correction,
                    group['weight_decay'])

        return loss
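
# Hedged sketch of the legacy calling convention exercised by the new test
# below: pre-scaled fp16 grads are passed explicitly, and fp16
# output_params receive a half-precision copy of the updated master weights.
import torch
from apex.optimizers import FusedAdam

master = [torch.nn.Parameter(torch.zeros(16, device="cuda"))]
fp16_copy = [torch.nn.Parameter(torch.zeros(16, device="cuda").half())]
opt = FusedAdam(master, lr=1e-3)
half_grads = [torch.randn(16, device="cuda").half()]
opt.step(grads=half_grads, output_params=fp16_copy)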
......@@ -3,6 +3,9 @@
// CUDA forward declaration
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
......@@ -25,4 +28,5 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("adam", &adam, "Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}
......@@ -9,6 +9,10 @@
#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
#include "type_shim.h"
......@@ -55,6 +59,93 @@ __global__ void adam_cuda_kernel(
}
}
template <int DEPTH, typename T, typename GRAD_T>
struct AdamFunctor
{
   __device__ __forceinline__ void operator()(
       int chunk_size,
       volatile int* noop_gmem,
       TensorListMetadata<DEPTH>& tl,
       const float b1,
       const float b2,
       const float eps,
       const float grad_scale,
       const float step_size,
       adamMode_t mode,
       const float decay)
   {
       int tensor_loc = tl.block_to_tensor[blockIdx.x];
       int chunk_idx = tl.block_to_chunk[blockIdx.x];
       int n = tl.sizes[tensor_loc];

       T* p = (T *)tl.addresses[0][tensor_loc];
       p += chunk_idx*chunk_size;
       T* m = (T *)tl.addresses[1][tensor_loc];
       m += chunk_idx*chunk_size;
       T* v = (T *)tl.addresses[2][tensor_loc];
       v += chunk_idx*chunk_size;
       GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
       g += chunk_idx*chunk_size;
       GRAD_T* p_copy = NULL;
       if (DEPTH == 5) {
           p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
           p_copy += chunk_idx*chunk_size;
       }

       n -= chunk_idx*chunk_size;

       T incoming_p[ILP];
       T incoming_m[ILP];
       T incoming_v[ILP];
       T incoming_g[ILP];

       for(int i_start = 0;
           i_start < n && i_start < chunk_size;
           i_start += blockDim.x*ILP) {

           #pragma unroll
           for(int ii = 0; ii < ILP; ii++) {
               incoming_p[ii] = 0;
               incoming_m[ii] = 0;
               incoming_v[ii] = 0;
               incoming_g[ii] = 0;

               int i = i_start + threadIdx.x + ii*blockDim.x;
               if (i < n && i < chunk_size) {
                   incoming_p[ii] = p[i];
                   incoming_m[ii] = m[i];
                   incoming_v[ii] = v[i];
                   incoming_g[ii] = static_cast<T>(g[i]);
               }
           }

           // note for clarification to future michael:
           // From a pure memory dependency perspective, there's likely no point unrolling
           // the write loop, since writes just fire off once their LDGs arrive.
           // Put another way, the STGs are dependent on the LDGs, but not on each other.
           // There is still compute ILP benefit from unrolling the loop though.
           #pragma unroll
           for(int ii = 0; ii < ILP; ii++) {
               int j = i_start + threadIdx.x + ii*blockDim.x;
               if(j < n && j < chunk_size) {
                   T scaled_grad = incoming_g[ii]/grad_scale;
                   m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                   v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                   float denom;
                   if (mode == ADAM_MODE_0)
                       denom = sqrtf(v[j] + eps);
                   else // Mode 1
                       denom = sqrtf(v[j]) + eps;
                   float update = (m[j]/denom) + (decay*incoming_p[ii]);
                   p[j] = incoming_p[ii] - (step_size*update);
                   if (DEPTH == 5) p_copy[j] = (GRAD_T) p[j];
               }
           }
       }
   }
};
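
// Hedged, standalone sketch (plain C++, not apex code) of the bookkeeping
// the functor reads from TensorListMetadata: the host assigns every CUDA
// block one (tensor, chunk) pair, so block_to_tensor/block_to_chunk simply
// enumerate ceil(size/chunk_size) chunks per tensor.
#include <cstdio>
#include <vector>

int main() {
    const int chunk_size = 2;
    std::vector<int> sizes = {5, 2};  // two tensors: 5 and 2 elements
    std::vector<int> block_to_tensor, block_to_chunk;
    for (int t = 0; t < (int)sizes.size(); ++t) {
        int n_chunks = (sizes[t] + chunk_size - 1) / chunk_size;  // ceil
        for (int c = 0; c < n_chunks; ++c) {
            block_to_tensor.push_back(t);
            block_to_chunk.push_back(c);
        }
    }
    // Prints 4 blocks: tensor 0 owns chunks 0..2, tensor 1 owns chunk 0.
    for (size_t b = 0; b < block_to_tensor.size(); ++b)
        std::printf("block %zu -> tensor %d, chunk %d\n",
                    b, block_to_tensor[b], block_to_chunk[b]);
    return 0;
}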
void fused_adam_cuda(
    at::Tensor & p,
    at::Tensor & p_copy,
......@@ -135,3 +226,110 @@ void fused_adam_cuda(
THCudaCheck(cudaGetLastError());
}
void fused_adam_cuda_mt(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
    float lr,
    float beta1,
    float beta2,
    float eps,
    float grad_scale,
    int step,
    int mode,
    int bias_correction,
    float decay) {

  // Constants
  float step_size = 0;
  if (bias_correction == 1) {
      const float bias_correction1 = 1 - std::pow(beta1, step);
      const float bias_correction2 = 1 - std::pow(beta2, step);
      step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
  }
  else {
      step_size = lr;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  size_t tl_sz = tensor_lists.size();
  AT_ASSERTM(tl_sz == 4 || tl_sz == 5, "expected tensor lists of size 4 or 5");

  if (tensor_lists[3][0].type().scalarType() == at::ScalarType::Half) {
      // All other values should be fp32 for half gradients
      AT_ASSERTM(tensor_lists[0][0].type().scalarType() == at::ScalarType::Float, "expected parameter to be of float type");
      // Dispatch is done on the gradient type
      if (tl_sz == 5) {
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
              using accscalar_t = at::acc_type<scalar_t, true>;
              multi_tensor_apply<5>(
                  BLOCK_SIZE,
                  chunk_size,
                  noop_flag,
                  tensor_lists,
                  AdamFunctor<5, accscalar_t, scalar_t>(),
                  beta1,
                  beta2,
                  eps,
                  grad_scale,
                  step_size,
                  (adamMode_t) mode,
                  decay);
          }));
      } else {
          AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
              using accscalar_t = at::acc_type<scalar_t, true>;
              multi_tensor_apply<4>(
                  BLOCK_SIZE,
                  chunk_size,
                  noop_flag,
                  tensor_lists,
                  AdamFunctor<4, accscalar_t, scalar_t>(),
                  beta1,
                  beta2,
                  eps,
                  grad_scale,
                  step_size,
                  (adamMode_t) mode,
                  decay);
          }));
      }
  } else {
      if (tl_sz == 5) {
          AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
              multi_tensor_apply<5>(
                  BLOCK_SIZE,
                  chunk_size,
                  noop_flag,
                  tensor_lists,
                  AdamFunctor<5, scalar_t, scalar_t>(),
                  beta1,
                  beta2,
                  eps,
                  grad_scale,
                  step_size,
                  (adamMode_t) mode,
                  decay);
          }));
      } else {
          AT_DISPATCH_FLOATING_TYPES(tensor_lists[3][0].type(), "adam_cuda_mt_kernel", ([&] {
              multi_tensor_apply<4>(
                  BLOCK_SIZE,
                  chunk_size,
                  noop_flag,
                  tensor_lists,
                  AdamFunctor<4, scalar_t, scalar_t>(),
                  beta1,
                  beta2,
                  eps,
                  grad_scale,
                  step_size,
                  (adamMode_t) mode,
                  decay);
          }));
      }
  }
  THCudaCheck(cudaGetLastError());
}
......@@ -15,15 +15,18 @@ class TestFusedAdam(unittest.TestCase):
    def tearDown(self):
        pass

    def gen_param_optim(self, tensors, ref_adam_option, tst_adam_option=None):
        ref_param = []
        tst_param = []
        for tensor in tensors:
            ref_param.append(torch.nn.Parameter(tensor.clone()))
            tst_param.append(torch.nn.Parameter(tensor.clone()))

        ref_optim = torch.optim.Adam(ref_param, **ref_adam_option)
        if tst_adam_option:
            tst_optim = apex.optimizers.FusedAdam(tst_param, **tst_adam_option)
        else:
            tst_optim = apex.optimizers.FusedAdam(tst_param, **ref_adam_option)

        return (ref_param, tst_param, ref_optim, tst_optim)
......@@ -42,8 +45,8 @@ class TestFusedAdam(unittest.TestCase):
    def get_max_diff(self, ref_param, tst_param):
        max_abs_diff = max_rel_diff = 0
        for p_ref, p_tst in zip(ref_param, tst_param):
            max_abs_diff_p = (p_ref - p_tst.type(p_ref.type())).abs().max().item()
            max_rel_diff_p = ((p_ref - p_tst.type(p_ref.type())) / p_ref).abs().max().item()

            if max_abs_diff_p > max_abs_diff: max_abs_diff = max_abs_diff_p
            if max_rel_diff_p > max_rel_diff: max_rel_diff = max_rel_diff_p
......@@ -173,6 +176,34 @@ class TestFusedAdam(unittest.TestCase):
        self.assertLessEqual(max_abs_diff, self.max_abs_diff)
        self.assertLessEqual(max_rel_diff, self.max_rel_diff)

    def test_multi_tensor(self):
        sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]]
        ref_adam_option = {'lr':5e-4, 'betas':(0.9, 0.999), 'eps':1e-08,
                           'weight_decay':0, 'amsgrad':False}
        tst_adam_option = dict(ref_adam_option, **{'use_mt':True})

        tensors = []
        fp16_params = []
        for size in sizes:
            tensors.append(torch.rand(size, dtype=torch.float, device='cuda'))
            fp16_params.append(torch.nn.Parameter(tensors[-1].clone().half()))
        ref_param, tst_param, ref_optim, tst_optim = \
            self.gen_param_optim(tensors, ref_adam_option, tst_adam_option)

        for i in range(self.iters):
            half_grads = self.gen_mixed_grad(ref_param, tst_param)
            ref_optim.step()
            tst_optim.step(grads=half_grads, output_params=fp16_params)
            max_abs_diff, max_rel_diff = self.get_max_diff(ref_param, tst_param)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
            max_abs_diff, max_rel_diff = self.get_max_diff(tst_param, fp16_params)
            self.assertLessEqual(max_abs_diff, self.max_abs_diff)
            self.assertLessEqual(max_rel_diff, self.max_rel_diff)
if __name__ == '__main__':
    script_path = os.path.dirname(os.path.realpath(__file__))
......