Commit 93f91cde authored by Kexin Yu

Merge remote-tracking branch 'upstream/master'

parents 33082d2b 80b90b9d
from .self_multihead_attn import SelfMultiheadAttn
from .encdec_multihead_attn import EncdecMultiheadAttn
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
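# Minimal usage sketch (illustrative values only). jit_dropout_add fuses the dropout-then-residual-add
# applied after the attention block; dropout is hard-coded to training mode because the callers below
# only take this path when is_training is True.
#
#   x        = torch.randn(80, 10, 1024)          # attention output [T, B, C]
#   residual = torch.randn(80, 10, 1024)          # original query input
#   out      = jit_dropout_add(x, residual, 0.1, True)
#   # equivalent to: residual + F.dropout(x, p=0.1, training=True)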
class EncdecMultiheadAttn(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
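# e.g. embed_dim=1024, num_heads=16 gives head_dim=64 and scaling = 64 ** -0.5 = 0.125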
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_kv = Parameter(torch.Tensor(2*embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_kv = Parameter(torch.Tensor(2*embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
self.register_parameter('in_proj_bias_q', None)
self.register_parameter('in_proj_bias_kv', None)
self.in_proj_bias_q = None
self.in_proj_bias_kv = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_nrm_gamma_weights', None)
self.register_parameter('lyr_nrm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_encdec_attn_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
nn.init.xavier_uniform_(self.in_proj_weight_kv)
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
nn.init.constant_(self.in_proj_bias_kv, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add:
if self.impl == 'fast' :
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not both be defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
return outputs,None
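# Minimal usage sketch (assumed, illustrative shapes). Note that even the 'default' impl allocates its
# intermediates on CUDA, so a GPU is required, and impl='fast' additionally needs the
# fast_encdec_multihead_attn extension:
#
#   attn = EncdecMultiheadAttn(1024, 16, dropout=0.1, bias=False, include_norm_add=False, impl='default').cuda()
#   q  = torch.randn(80, 10, 1024, device='cuda')   # [tgt_len, batch, embed_dim]
#   kv = torch.randn(70, 10, 1024, device='cuda')   # [src_len, batch, embed_dim]
#   out, _ = attn(q, kv, kv, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True)
#   # out: [80, 10, 1024]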
import torch
import torch.nn.functional as F
class EncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
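# Note: torch.addmm(bias, x2d, W.transpose(0,1)) on the flattened [seql_q*seqs, embed_dim] view computes
# the same result as F.linear(x2d, W, bias); the explicit GEMM form keeps the shapes aligned with the comments above.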
if use_biases_t[0]:
input_lin_q_results = torch.addmm(input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
# input1: (activations) [seql_k, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0]:
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads, head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
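# With the shapes used by the unit tests below (seql_q = seql_k = 80, seqs = 10, heads = 16, head_dim = 64):
#   queries [80, 160, 64] x keys [80, 160, 64] -> matmul1_results [160, 80, 80]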
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout is not executed during inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to produce the non-standard (transposed) output layout.
# Because PyTorch cannot currently run autograd through ops that write into a user-supplied output
# tensor, a hand-written backward pass is required.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else:
outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads_t[0], head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand so the key and value grads can be sliced out of it.
input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:,:,0,:]
values_grads = input_lin_kv_results_grads[:,:,1,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
# Softmax Grad (not a publicly documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Q Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
queries_grads = queries_grads.transpose(0,1).view(inputs_q.size(0)*inputs_q.size(1), heads_t[0]*head_dim)
input_q_grads = torch.mm(queries_grads, input_weights_q)
input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Input KV Linear GEMM - DGRAD
# input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)]
# output: [seql_k, seqs, embed_dim]
# GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0)*inputs_kv.size(1), heads_t[0]*2*head_dim)
input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))
# Input Q Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [embed_dim, embed_dim]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
input_weight_q_grads = torch.mm(queries_grads.transpose(0,1), inputs_q.view(inputs_q.size(0)*inputs_q.size(1), inputs_q.size(2)))
# Input KV Linear GEMM - WGRAD
# input1: (data grads) [seql_k*seqs, 2*embed_dim(2048)]
# input2: (activations) [seql_k*seqs, embed_dim(1024)]
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
if use_biases_t[0]:
input_bias_grads_q = torch.sum(queries_grads, 0)
input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
else:
input_bias_grads_q = None
input_bias_grads_kv = None
return None, None, None, None, \
input_q_grads, input_kv_grads, \
input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \
None, None
encdec_attn_func = EncdecAttnFunc.apply
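# Direct-call sketch of the autograd function (assumed, illustrative; a CUDA device is required because
# the intermediates above are allocated on 'cuda'). Arguments follow EncdecAttnFunc.forward:
# (use_time_mask, is_training, heads, scale, inputs_q, inputs_kv, w_q, w_kv, w_out,
#  bias_q, bias_kv, bias_out, mask, dropout_prob)
#
#   E, H = 1024, 16
#   q  = torch.randn(80, 10, E, device='cuda')
#   kv = torch.randn(70, 10, E, device='cuda')
#   w_q, w_kv, w_o = (torch.randn(E, E, device='cuda'),
#                     torch.randn(2 * E, E, device='cuda'),
#                     torch.randn(E, E, device='cuda'))
#   out = encdec_attn_func(False, True, H, (E // H) ** -0.5, q, kv, w_q, w_kv, w_o,
#                          None, None, None, None, 0.1)   # out: [80, 10, E]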
import torch
import fast_encdec_multihead_attn
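# Importing fast_encdec_multihead_attn succeeds only after the optional CUDA extension is built;
# the "--fast_multihead_attn" branch added to setup.py later in this commit compiles it.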
class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_encdec_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
input_q_grads, \
input_kv_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None
fast_encdec_attn_func = FastEncdecAttnFunc.apply
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import fast_encdec_multihead_attn_norm_add
class FastEncdecAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_encdec_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
#import pdb; pdb.set_trace()
return None, None, None, \
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads, \
None, None
fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply
import torch
import fast_self_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
input_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_grads, input_weight_grads, output_weight_grads, None, None
fast_self_attn_func = FastSelfAttnFunc.apply
import torch
import fast_self_multihead_attn_norm_add
class FastSelfAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_self_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
return None, None, None, \
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads, \
None, None
fast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class SelfMultiheadAttn(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.register_parameter('out_proj_bias', None)
self.in_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_nrm_gamma_weights', None)
self.register_parameter('lyr_nrm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_self_attn_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add:
if self.impl == 'fast':
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not both be defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
return outputs,None
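# Usage sketch (assumed, illustrative). Unlike EncdecMultiheadAttn, the Q/K/V projections are packed
# into a single [3*embed_dim, embed_dim] weight; a CUDA device is required by both impls:
#
#   attn = SelfMultiheadAttn(1024, 16, dropout=0.1, bias=False, include_norm_add=False, impl='default').cuda()
#   x = torch.randn(80, 10, 1024, device='cuda')    # [seq_len, batch, embed_dim]
#   out, _ = attn(x, x, x, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True)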
import torch
import torch.nn.functional as F
class SelfAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0]:
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads, 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout is not executed during inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to produce the non-standard (transposed) output layout.
# Because PyTorch cannot currently run autograd through ops that write into a user-supplied output
# tensor, a hand-written backward pass is required.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else:
outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand so the query, key, and value grads can be sliced out of it.
input_lin_results_grads = torch.empty_like(input_lin_results)
queries_grads = input_lin_results_grads[:,:,0,:]
keys_grads = input_lin_results_grads[:,:,1,:]
values_grads = input_lin_results_grads[:,:,2,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
# Softmax Grad (not a publicly documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)
input_grads = torch.mm(input_lin_results_grads, input_weights)
input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))
# Input Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, 3*embed_dim(3072)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [3*embed_dim, embed_dim]
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
if use_biases_t[0]:
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else:
input_bias_grads = None
return None, None, None, None, \
input_grads, \
input_weight_grads, output_weight_grads, \
input_bias_grads, output_bias_grads, \
None, None
self_attn_func = SelfAttnFunc.apply
import torch
import unittest
from apex.contrib.multihead_attn import EncdecMultiheadAttn
class EncdecMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_encdec_multihead_attn(self) :
grads = torch.randn_like(self.tst_inputs_q)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
def test_encdec_multihead_attn_time_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
time_mask_bool = time_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
def test_encdec_multihead_attn_pad_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
pad_mask_bool = pad_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
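# Note: these tests assume a CUDA device plus the fast_encdec_multihead_attn extension built via
# setup.py's --fast_multihead_attn option; they can be run directly (python <test file>) or through
# python -m unittest.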
import torch
import unittest
from apex.contrib.multihead_attn import EncdecMultiheadAttn
class EncdecMultiheadAttnNormAddTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_encdec_multihead_attn_norm_add(self) :
grads = torch.randn_like(self.tst_inputs_q)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn(self) :
grads = torch.randn_like(self.tst_inputs)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
def test_self_multihead_attn_time_mask(self) :
grads = torch.randn_like(self.tst_inputs)
time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
time_mask_bool= time_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
def test_self_multihead_attn_pad_mask(self) :
grads = torch.randn_like(self.tst_inputs)
pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
pad_mask_bool = pad_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnNormAddTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn_norm_add(self) :
grads = torch.randn_like(self.tst_inputs)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
@@ -56,7 +56,11 @@ void multi_tensor_apply(
for(int t = 0; t < tensor_lists[l].size(); t++)
{
// TODO: Print which tensor fails.
TORCH_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
@@ -29,13 +29,13 @@ With the new Amp API **you never need to explicitly convert your model, or the i
"Pure FP32" training:
```
$ python main_amp.py --opt-level O0
$ python main_amp.py --opt_level O0
```
Recommended mixed precision training:
```
$ python main_amp.py --opt-level O1
$ python main_amp.py --opt_level O1
```
Have a look at the original [DCGAN example](https://github.com/pytorch/examples/tree/master/dcgan) for more information about the arguments used.
To enable mixed precision training, we introduce the `--opt-level` argument.
To enable mixed precision training, we introduce the `--opt_level` argument.
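The `--opt_level` flag selects the Amp optimization level that the script forwards to `amp.initialize`. A minimal sketch of that call (assuming the parsed flag lands in `args.opt_level`; the surrounding model/optimizer setup is not shown in this diff):
```
from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
```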
@@ -6,6 +6,9 @@ import sys
import warnings
import os
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
if not torch.cuda.is_available():
# https://github.com/NVIDIA/apex/issues/486
# Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
@@ -88,7 +91,10 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0):
version_ge_1_3 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
version_ge_1_3 = ['-DVERSION_GE_1_3']
version_dependent_macros = version_ge_1_1 + version_ge_1_3
version_ge_1_5 = []
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
version_ge_1_5 = ['-DVERSION_GE_1_5']
version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
if "--cuda_ext" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
@@ -148,7 +154,7 @@ if "--bnp" in sys.argv:
'apex/contrib/csrc/groupbn/ipc.cu',
'apex/contrib/csrc/groupbn/interface.cpp',
'apex/contrib/csrc/groupbn/batch_norm_add_relu.cu'],
include_dirs=['csrc'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': [] + version_dependent_macros,
'nvcc':['-DCUDA_HAS_FP16=1',
'-D__CUDA_NO_HALF_OPERATORS__',
@@ -169,7 +175,7 @@ if "--xentropy" in sys.argv:
CUDAExtension(name='xentropy_cuda',
sources=['apex/contrib/csrc/xentropy/interface.cpp',
'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
include_dirs=['csrc'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros}))
@@ -187,9 +193,73 @@ if "--deprecated_fused_adam" in sys.argv:
CUDAExtension(name='fused_adam_cuda',
sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
include_dirs=['csrc'],
include_dirs=[os.path.join(this_dir, 'csrc')],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension.with_options(use_ninja=False)
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
setup(
@@ -7,6 +7,7 @@ from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from math import floor
from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
@@ -20,6 +21,10 @@ except ImportError as err:
print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
disabled = True
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
try_nhwc = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4)
class TestMultiTensorAxpby(unittest.TestCase):
@@ -31,28 +36,36 @@ class TestMultiTensorAxpby(unittest.TestCase):
self.xval = 4.0
self.yval = 16.0
self.overflow_buf = torch.cuda.IntTensor(1).zero_()
self.ref = torch.cuda.FloatTensor([136.0])
self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32)
def tearDown(self):
pass
# The tensor creation here is written for convenience, not speed.
def axpby(self, sizea, sizeb, applier, repeat_tensors,
x_type, y_type, out_type, inplace=False):
x_type, y_type, out_type, inplace=False, nhwc=False):
self.overflow_buf.zero_()
t1 = torch.cuda.FloatTensor(sizea).fill_(1.0)
t2 = torch.cuda.FloatTensor(sizeb).fill_(1.0)
sizea = sizea if isinstance(sizea, tuple) else (sizea,)
sizeb = sizeb if isinstance(sizeb, tuple) else (sizeb,)
t1 = torch.full(sizea, 1.0, device="cuda", dtype=torch.float32)
t2 = torch.full(sizeb, 1.0, device="cuda", dtype=torch.float32)
def to_fmt(t, tp):
if nhwc:
return t.clone().to(tp, memory_format=torch.channels_last)
else:
return t.clone().to(tp)
y_list = []
for i in range(repeat_tensors):
y_list += [t1.clone().to(y_type)*self.yval, t2.clone().to(y_type)*self.yval]
y_list += [to_fmt(t1, y_type)*self.yval, to_fmt(t2, y_type)*self.yval]
x_list = [x.clone().to(x_type)*(self.xval/self.yval) for x in y_list]
x_list = [to_fmt(x, x_type)*(self.xval/self.yval) for x in y_list]
if inplace:
out_list = y_list
else:
out_list = [out.clone().to(out_type)*3.0 for out in y_list]
out_list = [to_fmt(out, out_type)*3.0 for out in y_list]
applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
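As a sanity check on the expected value: each x entry ends up equal to xval = 4.0 (the y value rescaled by xval/yval) and each y entry equals yval = 16.0, so with the axpby coefficients defined outside this hunk (assumed here to be a = 2.0 and b = 8.0) the kernel should produce 2*4 + 8*16 = 136, matching `self.ref` above:
```
# Assumed coefficients -- self.a and self.b are set outside this hunk.
a, b = 2.0, 8.0
xval, yval = 4.0, 16.0
assert a * xval + b * yval == 136.0  # the value stored in self.ref
```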
@@ -122,6 +135,45 @@ class TestMultiTensorAxpby(unittest.TestCase):
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
# 2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
@unittest.skipIf(disabled, "amp_C is unavailable")
@unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
def test_fuzz_nhwc(self):
input_size_pairs = (
((7, 77, 7, 77), (5, 55, 5, 55)),
((1, 1, 777, 1), (1, 1, 555, 1)),
((5, 47, 5, 55), (1, 1, 1, 2048*32 + 1)),
((1, 1, 1, 2048*32 + 1), (55, 47, 5, 55)),
((555, 1, 1, 1), (32, 8, 32, 8)),
((32, 8, 32, 8), (55, 47, 5, 55)),
((1, 1, 33333, 1), (55, 47, 55, 5)),
((55, 47, 55, 5), (1, 1, 33333, 1)))
appliers = (
MultiTensorApply(2048*32),
MultiTensorApply(333),
MultiTensorApply(33333))
repeat_tensors = (
1,
55)
for sizea, sizeb in input_size_pairs:
for applier in appliers:
for repeat in repeat_tensors:
for x_type in (torch.float32, torch.float16):
for y_type in (torch.float32, torch.float16):
for out_type in (torch.float32, torch.float16):
for inplace in (True, False):
if inplace is True and (y_type is not out_type):
continue
else:
self.axpby(sizea, sizeb, applier, repeat,
x_type, y_type, out_type, inplace=inplace, nhwc=True)
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
# 0, 0, float('nan'), inplace=inplace)
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
# 2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
# self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
# 2*(repeat//2), sizea//2, float('inf'), inplace=inplace)
if __name__ == '__main__':