Unverified commit 37989915 authored by Ashish Farmer, committed by GitHub

Merge pull request #20 from lcskrishna/ifu_06052020

IFU_06_05_2020
parents b0c7d09f 097238f8
......@@ -6,7 +6,12 @@ import torch.nn.functional as F
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
if hasattr(torch._C, '_jit_set_profiling_executor') :
torch._C._jit_set_profiling_executor(False)
if hasattr(torch._C, '_jit_set_profiling_mode') :
torch._C._jit_set_profiling_mode(False)
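# Note (assumed intent, not stated in this diff): turning off the profiling executor/mode keeps the
# @torch.jit.script helpers below on the legacy JIT executor.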
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
......@@ -57,9 +62,9 @@ class EncdecMultiheadAttn(nn.Module):
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
elif impl == 'default' : self.attn_func = encdec_attn_func
......
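The change above swaps torch.nn.LayerNorm for Apex's FusedLayerNorm, a drop-in replacement backed by a fused CUDA kernel. A minimal usage sketch (device and shapes are illustrative):

import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm

embed_dim = 1024
lyr_nrm = FusedLayerNorm(embed_dim).cuda()
out = lyr_nrm(torch.randn(80, 10, embed_dim, device='cuda'))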
......@@ -203,7 +203,7 @@ class EncdecAttnFunc(torch.autograd.Function):
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
# Softmax Grad (not a publicly documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
......
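The corrected third argument reflects inverted-dropout scaling: the backward pass multiplies by the dropout mask and by 1/(1-p), not by p. A self-contained sketch of the same math without the private torch._masked_scale op (values are illustrative):

import torch

p = 0.1                                              # dropout probability
grad_out = torch.randn(4, 8)
mask = (torch.rand_like(grad_out) > p).to(grad_out.dtype)
# equivalent (assuming _masked_scale computes input * mask * scale) to
# torch._masked_scale(grad_out, mask, 1.0 / (1.0 - p)):
grad_in = grad_out * mask * (1.0 / (1.0 - p))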
import torch
import fast_self_multihead_attn
import fast_self_multihead_attn_bias
import fast_self_multihead_attn_bias_additive_mask
class FastSelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob):
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
if use_biases_t[0]:
if not mask_additive:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias_additive_mask.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
else:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
......@@ -38,10 +83,12 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
......@@ -53,22 +100,43 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
input_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
if use_biases_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_grads, input_weight_grads, output_weight_grads, None, None
else:
input_bias_grads = None
output_bias_grads = None
input_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_grads, input_weight_grads, output_weight_grads, input_bias_grads, output_bias_grads, None, None, None
fast_self_attn_func = FastSelfAttnFunc.apply
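For reference, a minimal call sketch of the exported function with biases and no padding mask. The argument order follows the forward() signature above; the [seq_len, batch, embed_dim] input layout, the fp16/CUDA requirement, and the need for the compiled fast_self_multihead_attn* extensions are assumptions here:

import torch
from apex.contrib.multihead_attn.fast_self_multihead_attn_func import fast_self_attn_func

seq_len, batch, embed_dim, heads = 80, 10, 1024, 16
inputs = torch.randn(seq_len, batch, embed_dim, dtype=torch.float16, device='cuda')
input_weights = torch.randn(3 * embed_dim, embed_dim, dtype=torch.float16, device='cuda')
output_weights = torch.randn(embed_dim, embed_dim, dtype=torch.float16, device='cuda')
input_biases = torch.zeros(3 * embed_dim, dtype=torch.float16, device='cuda')
output_biases = torch.zeros(embed_dim, dtype=torch.float16, device='cuda')
# use_time_mask, is_training, heads, inputs, input_weights, output_weights,
# input_biases, output_biases, pad_mask, mask_additive, dropout_prob
outputs = fast_self_attn_func(False, True, heads, inputs, input_weights, output_weights,
                              input_biases, output_biases, None, False, 0.1)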
import torch
import fast_mask_softmax_dropout
import fast_additive_mask_softmax_dropout
class MaskSoftmaxDropout(torch.autograd.Function) :
@staticmethod
def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive])
if mask_additive:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_additive_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask if use_mask else null_tensor, \
mask_additive_t, \
dropout_prob_t)
return dropout_results.detach()
@staticmethod
def backward(ctx, output_grads):
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask, \
mask_additive_t, \
dropout_prob_t = ctx.saved_tensors
if mask_additive_t[0]:
input_grads = \
fast_additive_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
dropout_prob_t[0])
else:
input_grads = \
fast_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
pad_mask, \
dropout_prob_t[0])
return None, None, input_grads, None, None, None
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply
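A minimal usage sketch of the exported wrapper, mirroring the unit test further down in this diff (the additive fp16 mask layout and the CUDA requirement are taken from that test):

import torch
from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func

heads, sequences, seq_length = 16, 10, 80
inputs = torch.randn(heads * sequences, seq_length, seq_length,
                     dtype=torch.float16, device='cuda', requires_grad=True)
pad_mask = (torch.randn(sequences, seq_length) > 0).cuda().half() * -10000
# is_training, heads, inputs, pad_mask, mask_additive, dropout_prob
out = fast_mask_softmax_dropout_func(True, heads, inputs, pad_mask, True, 0.0)
out.backward(torch.randn_like(out))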
......@@ -6,7 +6,12 @@ import torch.nn.functional as F
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
if hasattr(torch._C, '_jit_set_profiling_executor') :
torch._C._jit_set_profiling_executor(False)
if hasattr(torch._C, '_jit_set_profiling_mode') :
torch._C._jit_set_profiling_mode(False)
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
......@@ -21,7 +26,7 @@ class SelfMultiheadAttn(nn.Module):
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
......@@ -32,17 +37,38 @@ class SelfMultiheadAttn(nn.Module):
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.separate_qkv_params = separate_qkv_params
self.mask_additive = mask_additive
if mask_additive:
assert self.include_norm_add == False, "additive mask not supported with layer norm"
assert impl == 'default' or (impl == 'fast' and bias), "additive mask not supported for fast mode without bias"
if separate_qkv_params:
self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
else:
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
if separate_qkv_params:
self.q_bias = Parameter(torch.Tensor(embed_dim))
self.k_bias = Parameter(torch.Tensor(embed_dim))
self.v_bias = Parameter(torch.Tensor(embed_dim))
else:
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
self.register_parameter('in_proj_bias', None)
if separate_qkv_params:
self.register_parameter('q_bias', None)
self.register_parameter('k_bias', None)
self.register_parameter('v_bias', None)
self.q_bias = None
self.k_bias = None
self.v_bias = None
else:
self.register_parameter('in_proj_bias', None)
self.in_proj_bias = None
self.register_parameter('out_proj_bias', None)
self.in_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
......@@ -54,7 +80,7 @@ class SelfMultiheadAttn(nn.Module):
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
......@@ -67,10 +93,20 @@ class SelfMultiheadAttn(nn.Module):
else : assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
if self.separate_qkv_params:
nn.init.xavier_uniform_(self.q_weight)
nn.init.xavier_uniform_(self.k_weight)
nn.init.xavier_uniform_(self.v_weight)
else:
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias, 0.)
if self.separate_qkv_params:
nn.init.constant_(self.q_bias, 0.)
nn.init.constant_(self.k_bias, 0.)
nn.init.constant_(self.v_bias, 0.)
else:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add:
if self.impl == 'fast':
......@@ -88,10 +124,22 @@ class SelfMultiheadAttn(nn.Module):
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if self.separate_qkv_params:
input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()
else:
input_weights = self.in_proj_weight
if self.bias:
if self.separate_qkv_params:
input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)],dim=1).reshape(3*self.embed_dim).contiguous()
else:
input_bias = self.in_proj_bias
else:
input_bias=None
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None:
assert self.mask_additive == False, "additive mask not supported for time mask"
mask = attn_mask
else:
mask = None
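For intuition on the concatenation above: the separate projection weights are interleaved head by head, so the resulting [3*embed_dim, embed_dim] matrix is grouped as (q, k, v) per head rather than as three embed_dim-sized blocks. A small standalone illustration (shapes are arbitrary):

import torch

embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads
q_weight = torch.zeros(embed_dim, embed_dim)        # marker values: q=0, k=1, v=2
k_weight = torch.ones(embed_dim, embed_dim)
v_weight = 2 * torch.ones(embed_dim, embed_dim)
input_weights = torch.cat([q_weight.view(num_heads, 1, head_dim, embed_dim),
                           k_weight.view(num_heads, 1, head_dim, embed_dim),
                           v_weight.view(num_heads, 1, head_dim, embed_dim)],
                          dim=1).reshape(3 * embed_dim, embed_dim)
# rows come out as [q_h0, k_h0, v_h0, q_h1, k_h1, v_h1], head_dim rows each
print(input_weights[:, 0].view(num_heads, 3, head_dim)[:, :, 0])   # tensor([[0., 1., 2.], [0., 1., 2.]])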
......@@ -100,12 +148,12 @@ class SelfMultiheadAttn(nn.Module):
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
input_weights, self.out_proj_weight, mask, self.dropout)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
input_weights, self.out_proj_weight,
input_bias, self.out_proj_bias,
mask, self.dropout)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
......@@ -114,11 +162,11 @@ class SelfMultiheadAttn(nn.Module):
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
input_weights, self.out_proj_weight,
input_bias, self.out_proj_bias,
mask, self.mask_additive, self.dropout)
return outputs,None
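Putting the new options together, a hedged end-to-end sketch of the module with separate Q/K/V parameters and an additive key-padding mask; the [seq_len, batch, embed_dim] query layout and the fp16/CUDA requirement for impl='fast' are assumptions based on the surrounding tests:

import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn

embed_dim, heads, seq_len, batch = 1024, 16, 80, 10
attn = SelfMultiheadAttn(embed_dim, heads, dropout=0.1, bias=True, impl='fast',
                         separate_qkv_params=True, mask_additive=True).cuda().half()
query = torch.randn(seq_len, batch, embed_dim, dtype=torch.float16, device='cuda')
pad_mask = (torch.randn(batch, seq_len) > 0).cuda().half() * -10000   # additive mask, batch x src_len
outputs, _ = attn(query, query, query, key_padding_mask=pad_mask)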
......@@ -264,9 +264,11 @@ class DistributedFusedAdam(torch.optim.Optimizer):
grp = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._rs_pg.append(grp)
if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
if self._compute_L2_grad_norm:
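# torch.distributed.new_group is a collective: every rank must call it, in the same order,
# so the group is created unconditionally and only cached on ranks that belong to it.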
l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
if torch.distributed.get_rank() in ranks:
self._l2_grad_norm_pg = l2_grad_norm_pg
torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
for rs_pg in self._rs_pg:
torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
......
This diff is collapsed.
import torch
import unittest
import torch.nn.functional as F
from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func
class FusedSoftmaxTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.mask = (torch.randn(self.sequences,self.seq_length)>0).cuda()
self.mask = self.mask.half()*-10000
self.ref_inputs = torch.randn(self.heads * self.sequences, self.seq_length, self.seq_length,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs = self.ref_inputs.clone().detach().requires_grad_(True)
def test_fused_softmax(self) :
grads = torch.randn_like(self.tst_inputs)
y_ref = self.ref_inputs.view(self.sequences, self.heads, self.seq_length, self.seq_length)
y_ref = y_ref + self.mask.unsqueeze(1).unsqueeze(2)
y_ref = y_ref.view(self.sequences*self.heads, self.seq_length, self.seq_length)
y_ref = F.softmax(y_ref, dim=-1)
y_ref = torch._fused_dropout(y_ref, 1.0)
y_tst = fast_mask_softmax_dropout_func(True, self.heads, self.tst_inputs, self.mask, True, 0.0)
y_ref[0].backward(grads)
y_tst.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(y_ref[0], y_tst, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
......@@ -51,6 +51,8 @@ class FusedLAMB(torch.optim.Optimizer):
method is called. (default: True)
max_grad_norm (float, optional): value used to clip global grad norm
(default: 1.0)
use_nvlamb (boolean, optional): Apply the adaptive learning rate (trust ratio)
to parameters with 0.0 weight decay as well (default: False)
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
......@@ -62,7 +64,7 @@ class FusedLAMB(torch.optim.Optimizer):
betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
amsgrad=False, adam_w_mode=True,
grad_averaging=True, set_grad_none=True,
max_grad_norm=1.0):
max_grad_norm=1.0, use_nvlamb=False):
if amsgrad:
raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
defaults = dict(lr=lr, bias_correction=bias_correction,
......@@ -72,6 +74,7 @@ class FusedLAMB(torch.optim.Optimizer):
super(FusedLAMB, self).__init__(params, defaults)
if multi_tensor_applier.available:
import amp_C
self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
self.multi_tensor_lamb = amp_C.multi_tensor_lamb
......@@ -80,6 +83,7 @@ class FusedLAMB(torch.optim.Optimizer):
self.adam_w_mode = 1 if adam_w_mode else 0
self.set_grad_none = set_grad_none
self.use_nvlamb = use_nvlamb
def zero_grad(self):
if self.set_grad_none:
......@@ -100,6 +104,37 @@ class FusedLAMB(torch.optim.Optimizer):
if closure is not None:
loss = closure()
# create separate grad lists for fp32 and fp16 params
g_all_32, g_all_16 = [], []
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
if p.dtype == torch.float32:
g_all_32.append(p.grad.data)
elif p.dtype == torch.float16:
g_all_16.append(p.grad.data)
else:
raise RuntimeError('FusedLAMB only supports fp16 and fp32.')
g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
# compute grad norm for two lists
if len(g_all_32) > 0:
g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_32], False)[0]
if len(g_all_16) > 0:
g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[g_all_16], False)[0]
# blend two grad norms to get global grad norm
global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm,
self._dummy_overflow_buf,
[[g_norm_32, g_norm_16]],
False)[0]
max_grad_norm = self.defaults['max_grad_norm']
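The "blend" is simply a root-sum-of-squares of the two partial norms, so global_grad_norm equals the L2 norm over all gradients regardless of dtype. A quick standalone check of that identity (illustrative only):

import torch

g_all_32 = [torch.randn(5), torch.randn(3)]
g_all_16 = [torch.randn(4)]
g_norm_32 = torch.norm(torch.cat([g.flatten() for g in g_all_32]))
g_norm_16 = torch.norm(torch.cat([g.flatten() for g in g_all_16]))
blended = torch.norm(torch.stack([g_norm_32, g_norm_16]))           # sqrt(n32**2 + n16**2)
full = torch.norm(torch.cat([g.flatten() for g in g_all_32 + g_all_16]))
assert torch.allclose(blended, full)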
for group in self.param_groups:
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']
......@@ -156,7 +191,9 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
self.use_nvlamb)
if(len(g_32) > 0):
multi_tensor_applier(self.multi_tensor_lamb,
self._dummy_overflow_buf,
......@@ -170,6 +207,8 @@ class FusedLAMB(torch.optim.Optimizer):
group['weight_decay'],
grad_averaging,
self.adam_w_mode,
group['max_grad_norm'])
global_grad_norm,
max_grad_norm,
self.use_nvlamb)
return loss
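A minimal usage sketch of the updated optimizer; the import path is the standard Apex one, and the model and data here are placeholders:

import torch
from apex.optimizers import FusedLAMB

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = FusedLAMB(model.parameters(), lr=1e-3, weight_decay=0.01,
                      max_grad_norm=1.0, use_nvlamb=False)
loss = model(torch.randn(32, 1024, device='cuda')).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()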
......@@ -204,6 +204,13 @@ def patchClass(cls):
add_wrapper(cls, f)
def init():
string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
"soon be removed from Apex. Please visit\n" + \
"https://github.com/NVIDIA/PyProf\n" + \
"for the latest version.\n\n"
# print regardless of warning state
print(string)
print("Initializing NVTX monkey patches")
for cls in [torch, torch.Tensor, torch.nn.functional,]:
patchClass(cls)
......
......@@ -42,7 +42,7 @@ void multi_tensor_lamb_stage1_cuda(
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
at::Tensor global_grad_norm,
const float max_global_grad_norm);
void multi_tensor_lamb_stage2_cuda(
......@@ -51,7 +51,9 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float step_size);
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python);
void multi_tensor_adam_cuda(
int chunk_size,
......@@ -106,7 +108,9 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm);
at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
......
......@@ -52,8 +52,8 @@ struct LAMBStage1Functor
const float epsilon,
adamMode_t mode,
const float decay,
float* global_grad_norm,
float max_global_grad_norm)
const float* global_grad_norm,
const float max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -239,7 +239,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>* tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -250,9 +252,15 @@ struct LAMBStage2Functor
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
MATH_T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* update = (T*)tl->addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
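The ratio computed above is the LAMB trust-ratio scaling of the learning rate: with use_nvlamb it is applied to every parameter, otherwise only to parameters with non-zero weight decay. The same rule as a scalar Python sketch (function name is illustrative):

def lamb_step_size(learning_rate, param_norm, update_norm, decay, use_nvlamb):
    # apply the adaptive (trust-ratio) learning rate to all params under NVLAMB,
    # otherwise only when weight decay is active for the parameter
    if use_nvlamb or decay != 0.0:
        if param_norm != 0.0 and update_norm != 0.0:
            return learning_rate * (param_norm / update_norm)
    return learning_rate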
......@@ -334,12 +342,16 @@ void multi_tensor_lamb_cuda(
const float weight_decay,
const int grad_averaging,
const int mode,
const float max_grad_norm)
at::Tensor global_grad_norm,
const float max_grad_norm,
at::optional<bool> use_nvlamb_python)
{
using namespace at;
// Master weight and 32bit momentum(potentially changing) is not handled by this
// So we assume every tensor are all in the same type
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
......@@ -354,9 +366,6 @@ void multi_tensor_lamb_cuda(
std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
// Compute global grad norm
auto grad_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, false);
// Compute per tensor param norm
auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);
......@@ -378,7 +387,7 @@ void multi_tensor_lamb_cuda(
epsilon,
(adamMode_t) mode,
weight_decay,
std::get<0>(grad_norm_tuple).DATA_PTR<float>(),
global_grad_norm.DATA_PTR<float>(),
max_grad_norm); )
// Compute update norms
......@@ -395,7 +404,9 @@ void multi_tensor_lamb_cuda(
LAMBStage2Functor<scalar_t_0>(),
std::get<1>(param_norm_tuple).DATA_PTR<float>(),
std::get<1>(update_norm_tuple).DATA_PTR<float>(),
lr); )
lr,
weight_decay,
use_nvlamb); )
AT_CUDA_CHECK(cudaGetLastError());
......
......@@ -118,12 +118,13 @@ void multi_tensor_lamb_stage1_cuda(
const float beta1,
const float beta2,
const float epsilon,
const float global_grad_norm,
at::Tensor global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;
float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
const float* g_grad_norm = global_grad_norm.DATA_PTR<float>();
float clipped_global_grad_norm = *(g_grad_norm) > max_global_grad_norm ? *(g_grad_norm) / max_global_grad_norm : 1.0f;
float next_step = float(step+1);
float beta1_correction = 1.0f - std::pow(beta1, next_step);
float beta2_correction = 1.0f - std::pow(beta2, next_step);
......
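The clipped_global_grad_norm above is used downstream as a divisor for the gradients (an assumption about the kernel body, which is not shown in this hunk): when the global gradient norm exceeds max_global_grad_norm, dividing by gnorm/max rescales the gradients to that maximum norm; otherwise the divisor is 1.0 and they pass through unchanged. A scalar illustration with made-up values:

def grad_clip_divisor(global_grad_norm, max_global_grad_norm):
    if global_grad_norm > max_global_grad_norm:
        return global_grad_norm / max_global_grad_norm
    return 1.0

assert grad_clip_divisor(8.0, 1.0) == 8.0   # grads divided by 8 -> new norm is 1.0
assert grad_clip_divisor(0.5, 1.0) == 1.0   # below the threshold -> unchanged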
......@@ -13,6 +13,8 @@
#define BLOCK_SIZE 512
#define ILP 4
using MATH_T = float;
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename UPD_T>
......@@ -24,7 +26,9 @@ struct LAMBStage2Functor
TensorListMetadata<2>* tl,
const float* per_tensor_param_norm,
const float* per_tensor_update_norm,
const float learning_rate)
const float learning_rate,
const float decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
......@@ -35,9 +39,15 @@ struct LAMBStage2Functor
int chunk_idx = tl->block_to_chunk[blockIdx.x];
int n = tl->sizes[tensor_loc];
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
T ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != 0.0))
{
float param_norm = per_tensor_param_norm[tensor_num];
float update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
T* p = (T*)tl->addresses[0][tensor_loc];
p += chunk_idx*chunk_size;
......@@ -87,8 +97,12 @@ void multi_tensor_lamb_stage2_cuda(
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate)
const float lr,
const float weight_decay,
at::optional<bool> use_nvlamb_python)
{
bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;
using namespace at;
DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
......@@ -101,7 +115,9 @@ void multi_tensor_lamb_stage2_cuda(
LAMBStage2Functor<scalar_t_0, scalar_t_1>(),
per_tensor_param_norm.DATA_PTR<float>(),
per_tensor_update_norm.DATA_PTR<float>(),
learning_rate); ))
lr,
weight_decay,
use_nvlamb); ))
AT_CUDA_CHECK(cudaGetLastError());
......
......@@ -24,7 +24,7 @@ if not torch.cuda.is_available():
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
print("torch.__version__ = ", torch.__version__)
print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
......@@ -37,6 +37,11 @@ ext_modules = []
extras = {}
if "--pyprof" in sys.argv:
string = "\n\nPyprof has been moved to its own dedicated repository and will " + \
"soon be removed from Apex. Please visit\n" + \
"https://github.com/NVIDIA/PyProf\n" + \
"for the latest version."
warnings.warn(string, DeprecationWarning)
with open('requirements.txt') as f:
required_packages = f.read().splitlines()
extras['pyprof'] = required_packages
......@@ -98,6 +103,25 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--distributed_lamb" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--distributed_lamb")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--distributed_lamb was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
ext_modules.append(
CUDAExtension(name='distributed_lamb_cuda',
sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--cuda_ext")
......@@ -293,6 +317,58 @@ if "--fast_multihead_attn" in sys.argv:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
......