Unverified Commit 3f94528e authored by Kevin Stephano, committed by GitHub

Add Fast Multihead Attention to APEX Contrib (#697)

* Adding C++ Multihead Attention implementation to contrib.

* Add reference test that at least works for forward.

* Remove CublasLt support from multihead attention.

* Add new Python version of self attention.

* Update python model of MHA with backward pass.

* Fixed Output Linear connection in MHA.

* Clean up compiles and add documentation to PySelfAttention.

* Add Encdec Python version of multihead attention.  Cleanup files.

* Tests for self and encdec multihead attention.

* Add reference pytorch implementation of attention with norm and add.

* Add cutlass branch definition.

* Add cutlass download to compile.

* Add norm/add tests.

* Add biases to pytorch python versions.

* Add tests and fix issues with python version of attention masking.

* Create README.md

* Update README.md

* Update README.md

* Update perf test parameters.

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Fix matmul1 output tensor size.  Fix tests that missed issue.
parent 494f8ab3
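For orientation, here is a minimal usage sketch of the encoder-decoder attention module this PR adds. It follows the constructor and forward signatures exercised in the unit tests below; the import path apex.contrib.multihead_attn and the fp16/CUDA setup are taken from those tests, while the concrete sizes are illustrative only.

import torch
from apex.contrib.multihead_attn import EncdecMultiheadAttn

# 'fast' selects the fused CUDA kernels added in this PR; 'default' selects the
# pure-PyTorch reference implementation shown in the files below.
attn = EncdecMultiheadAttn(1024, 16, dropout=0.0, bias=False,
                           include_norm_add=False, impl='fast')
attn.cuda().half()
attn.reset_parameters()

# Inputs follow the Time x Batch x Channel layout used throughout this PR.
q  = torch.randn(80, 10, 1024, dtype=torch.float16, device='cuda', requires_grad=True)
kv = torch.randn(80, 10, 1024, dtype=torch.float16, device='cuda', requires_grad=True)

outputs, _ = attn(q, kv, kv, key_padding_mask=None, need_weights=False,
                  attn_mask=None, is_training=True)
outputs.backward(torch.randn_like(outputs))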
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class EncdecAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob) :
use_biases_t = Variable(torch.tensor([input_biases_q is not None]))
heads_t = Variable(torch.tensor([heads]))
scale_t = Variable(torch.tensor([scale]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
if use_biases_t[0] :
input_lin_q_results = torch.addmm(input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0,1),
beta=1., alpha=1.)
else :
input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
# input1: (activations) [seql_k, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0] :
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0,1),
beta=1., alpha=1.)
else :
input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads, head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Matmul1 Batched GEMMs
# The output tensor is allocated before the batched GEMM because baddbmm requires the output to be specified.
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None :
# Self Attention Time Mask
if use_time_mask :
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else :
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training :
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else :
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to produce the non-standard output layout.
# Because PyTorch cannot currently perform autograd through an op with an explicitly specified output tensor,
# a manually specified backward pass is required.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0] :
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else :
outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads_t[0], head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand to properly slice out query, key, and value grads.
input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:,:,0,:]
values_grads = input_lin_kv_results_grads[:,:,1,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0] :
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else :
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_k, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
# Softmax Grad (not a publicly documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Q Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
queries_grads = queries_grads.transpose(0,1).view(inputs_q.size(0)*inputs_q.size(1), heads_t[0]*head_dim)
input_q_grads = torch.mm(queries_grads, input_weights_q)
input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Input KV Linear GEMM - DGRAD
# input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)]
# output: [seql_k, seqs, embed_dim]
# GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0)*inputs_kv.size(1), heads_t[0]*2*head_dim)
input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))
# Input Q Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [embed_dim, embed_dim]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
input_weight_q_grads = torch.mm(queries_grads.transpose(0,1), inputs_q.view(inputs_q.size(0)*inputs_q.size(1), inputs_q.size(2)))
# Input KV Linear GEMM - WGRAD
# input1: (data grads) [seql_k*seqs, 2*embed_dim(2048)]
# input2: (activations) [seql_k*seqs, embed_dim(1024)]
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
if use_biases_t[0] :
input_bias_grads_q = torch.sum(queries_grads, 0)
input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
else :
input_bias_grads_q = None
input_bias_grads_kv = None
return None, None, None, None, \
input_q_grads, input_kv_grads, \
input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \
None, None
encdec_attn_func = EncdecAttnFunc.apply
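For clarity on the argument order of the reference autograd function defined above, a small hedged sketch of calling encdec_attn_func directly. Shapes follow the GEMM comments in the forward pass ([embed_dim, embed_dim] for the Q projection, [2*embed_dim, embed_dim] for the packed KV projection); the concrete sizes and random weights are illustrative only, and biases and masking are skipped by passing None.

import torch

seql_q, seql_k, seqs, heads, embed_dim = 4, 6, 2, 16, 1024
head_dim = embed_dim // heads
scale = head_dim ** -0.5

inputs_q  = torch.randn(seql_q, seqs, embed_dim, dtype=torch.float16, device='cuda')
inputs_kv = torch.randn(seql_k, seqs, embed_dim, dtype=torch.float16, device='cuda')
w_q  = torch.randn(embed_dim,     embed_dim, dtype=torch.float16, device='cuda', requires_grad=True)
w_kv = torch.randn(2 * embed_dim, embed_dim, dtype=torch.float16, device='cuda', requires_grad=True)
w_o  = torch.randn(embed_dim,     embed_dim, dtype=torch.float16, device='cuda', requires_grad=True)

# Argument order: use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
#                 input_weights_q, input_weights_kv, output_weights,
#                 input_biases_q, input_biases_kv, output_biases, mask, dropout_prob
out = encdec_attn_func(False, True, heads, scale, inputs_q, inputs_kv,
                       w_q, w_kv, w_o, None, None, None, None, 0.0)
out.sum().backward()   # populates gradients for the three weight tensors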
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_encdec_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
input_q_grads, \
input_kv_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None
fast_encdec_attn_func = FastEncdecAttnFunc.apply
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn_norm_add
class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_encdec_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
#import pdb; pdb.set_trace()
return None, None, None, \
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads, \
None, None
fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
input_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_grads, input_weight_grads, output_weight_grads, None, None
fast_self_attn_func = FastSelfAttnFunc.apply
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn_norm_add
class FastSelfAttnNormAddFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_self_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
return None, None, None, \
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads, \
None, None
fast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training) :
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class SelfMultiheadAttn(nn.Module):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
self.register_parameter('in_proj_bias', None)
self.register_parameter('out_proj_bias', None)
self.in_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add :
if impl == 'fast' :
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else :
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add :
if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else :
if impl == 'fast' : self.attn_func = fast_self_attn_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight)
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add :
if self.impl == 'fast' :
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else :
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True) :
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None :
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not both be defined!"
mask = key_padding_mask
elif attn_mask is not None :
mask = attn_mask
else :
mask = None
if self.include_norm_add :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else :
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
if is_training :
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else :
outputs = outputs + query
else :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
return outputs,None
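To make the docstring above concrete, a hedged sketch of calling SelfMultiheadAttn with a key padding mask. The mask layout (batch x src_len, with 1s marking padded positions) comes from the docstring; impl='default' keeps the computation inside the pure-PyTorch path shown in this commit, and the sizes are illustrative only.

import torch
from apex.contrib.multihead_attn import SelfMultiheadAttn

layer = SelfMultiheadAttn(1024, 16, dropout=0.0, bias=False,
                          include_norm_add=False, impl='default')
layer.cuda().half()
layer.reset_parameters()

seq_len, batch = 80, 10
x = torch.randn(seq_len, batch, 1024, dtype=torch.float16, device='cuda')

# Key padding mask: [batch, src_len], 1s mark positions excluded from the keys.
pad_mask = torch.zeros(batch, seq_len, dtype=torch.uint8, device='cuda')
pad_mask[:, seq_len // 2:] = 1   # pretend the second half of each sequence is padding

outputs, _ = layer(x, x, x, key_padding_mask=pad_mask, need_weights=False,
                   attn_mask=None, is_training=True)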
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class SelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob) :
use_biases_t = Variable(torch.tensor([input_biases is not None]))
heads_t = Variable(torch.tensor([heads]))
scale_t = Variable(torch.tensor([scale]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0] :
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0,1),
beta=1., alpha=1.)
else :
input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads, 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Matmul1 Batched GEMMs
# The output tensor is allocated before the batched GEMM because baddbmm requires the output to be specified.
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None :
# Self Attention Time Mask
if use_time_mask :
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else :
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training :
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else :
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to produce the non-standard output layout.
# Because PyTorch cannot currently perform autograd through an op with an explicitly specified output tensor,
# a manually specified backward pass is required.
# Input1: from_softmax [seqs*heads, seql_q, seql_k]
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0] :
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else :
outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear output (should only impact metadata, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact metadata, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand to properly slice out query, key, and value grads.
input_lin_results_grads = torch.empty_like(input_lin_results)
queries_grads = input_lin_results_grads[:,:,0,:]
keys_grads = input_lin_results_grads[:,:,1,:]
values_grads = input_lin_results_grads[:,:,2,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ embed_dim, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0] :
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else :
output_bias_grads = None
# Matmul2 - DGRAD1
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
# Matmul2 - DGRAD2
# Input1: (activations) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_k, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
# Softmax Grad (not a publicly documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)
input_grads = torch.mm(input_lin_results_grads, input_weights)
input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))
# Input Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, 3*embed_dim(3072)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [3*embed_dim, embed_dim]
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
if use_biases_t[0] :
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else :
input_bias_grads = None
return None, None, None, None, \
input_grads, \
input_weight_grads, output_weight_grads, \
input_bias_grads, output_bias_grads, \
None, None
self_attn_func = SelfAttnFunc.apply
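As a companion to the GEMM-by-GEMM comments above, here is a compact restatement of the same computation written with plain batched matmuls, with masking and dropout omitted. This is a hedged sketch added for readability, not part of the commit; the packed [heads, 3, head_dim] layout of the QKV projection mirrors the view used in SelfAttnFunc.forward, and the function name is illustrative.

import torch
import torch.nn.functional as F

def reference_self_attn(inputs, input_weights, output_weights, heads, scale):
    # inputs: [seq_len, batch, embed_dim]; input_weights: [3*embed_dim, embed_dim]
    seq_len, batch, embed_dim = inputs.shape
    head_dim = embed_dim // heads

    # Packed QKV projection, then the same view/slice as SelfAttnFunc.forward.
    qkv = torch.matmul(inputs, input_weights.t())                 # [seq_len, batch, 3*embed_dim]
    qkv = qkv.view(seq_len, batch * heads, 3, head_dim)
    q, k, v = qkv[:, :, 0, :], qkv[:, :, 1, :], qkv[:, :, 2, :]

    # Matmul1 + softmax (the masking and dropout steps are omitted here).
    scores = torch.bmm(q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2)) * scale
    probs = F.softmax(scores, dim=-1)                             # [batch*heads, seq_len, seq_len]

    # Matmul2 + output projection.
    ctx = torch.bmm(probs, v.transpose(0, 1))                     # [batch*heads, seq_len, head_dim]
    ctx = ctx.transpose(0, 1).contiguous().view(seq_len, batch, embed_dim)
    return torch.matmul(ctx, output_weights.t())                  # [seq_len, batch, embed_dim]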
import torch
import unittest
from apex.contrib.multihead_attn import EncdecMultiheadAttn
class EncdecMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_encdec_multihead_attn(self) :
grads = torch.randn_like(self.tst_inputs_q)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
def test_encdec_multihead_attn_time_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
time_mask_bool = time_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
def test_encdec_multihead_attn_pad_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
pad_mask_bool = pad_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.multihead_attn import EncdecMultiheadAttn
class EncdecMultiheadAttnNormAddTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_encdec_multihead_attn_norm_add(self) :
grads = torch.randn_like(self.tst_inputs_q)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn(self) :
grads = torch.randn_like(self.tst_inputs)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
def test_self_multihead_attn_time_mask(self) :
grads = torch.randn_like(self.tst_inputs)
time_mask_byte= torch.triu(torch.ones(self.tst_inputs.size(0), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
time_mask_bool= time_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
def test_self_multihead_attn_pad_mask(self) :
grads = torch.randn_like(self.tst_inputs)
pad_mask_byte = torch.tril(torch.ones(self.tst_inputs.size(1), self.tst_inputs.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
pad_mask_bool = pad_mask_byte.to(torch.bool)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
import torch
import unittest
from apex.contrib.multihead_attn import SelfMultiheadAttn
class SelfMultiheadAttnNormAddTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0
self.ref_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='default')
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
self.tst_layer = SelfMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='fast')
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()
self.tst_inputs = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
def test_self_multihead_attn_norm_add(self) :
grads = torch.randn_like(self.tst_inputs)
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs,
self.ref_inputs,
self.ref_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
tst_outputs,_ = self.tst_layer.forward(self.tst_inputs,
self.tst_inputs,
self.tst_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs.backward(grads)
tst_outputs.backward(grads)
self.assertTrue(torch.allclose(self.ref_inputs, self.tst_inputs, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs.grad, self.tst_inputs.grad, atol=1e-3, rtol=1e-3))
if __name__ == '__main__':
unittest.main()
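The tests above use the standard unittest machinery, so they can also be collected with a discovery run; a hedged example follows (the directory apex/contrib/test/multihead_attn is a guess at where these test modules live and may differ in the actual repo).

import unittest

# Hypothetical path; point this at wherever the multihead_attn test modules are installed.
suite = unittest.defaultTestLoader.discover('apex/contrib/test/multihead_attn')
unittest.TextTestRunner(verbosity=2).run(suite)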
@@ -192,6 +192,71 @@ if "--deprecated_fused_adam" in sys.argv:
'nvcc':['-O3',
'--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
import subprocess
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
setup(
name='apex',
version='0.1',