Unverified commit 1733946a authored by Kevin Stephano, committed by GitHub

Change to Multihead Attention to allow Batched GEMMs larger than 64K. (#728)

* Adding C++ Multihead Attention implementation to contrib.

* Add reference test that at least works for forward.

* Remove CublasLt support from multihead attention.

* Add new Python version of self attention.

* Update python model of MHA with backward pass.

* Fixed Output Linear connection in MHA.

* Clean up compiles and add documentation to PySelfAttention.

* Add Encdec Python version of multihead attention. Clean up files.

* Tests for self and encdec multihead attention.

* Add reference pytorch implementation of attention with norm and add.

* Add cutlass branch definition.

* Add cutlass download to compile.

* Add norm/add tests.

* Add biases to pytorch python versions.

* Add tests and fix issues with python version of attention masking.

* Create README.md

* Update README.md

* Update README.md

* Update perf test parameters.

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Fix matmul1 output tensor size. Fix tests that missed the issue.

* Allow for Z dimensions of 64K and greater on batched GEMMs.

* remove redundant imports

* general cleanup, remove deprecated or unused functions
parent 50338df6
@@ -100,9 +100,48 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits.
// To implement batched GEMM with larger batch size, we fragment it into
// smaller batched GEMMs of gridDim.z <= 64k
long batchesLeft = batchCount;
long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
do {
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
int result = params.initialize(
m, // M dimension for each batch
n, // N dimension for each batch
k, // K dimension for each batch
alpha, // scalar alpha
a,
lda,
strideA, // distance in memory between the first element of neighboring batch
b,
ldb,
strideB, // distance in memory between the first element of neighboring batch
beta, // scalar beta
c, // source matrix C
ldc,
strideC, // distance in memory between the first element of neighboring batch
c, // destination matrix C (may be different memory than source C matrix)
ldc,
strideC, // distance in memory between the first element of neighboring batch
iterBatchCount
);
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the CUTLASS GEMM kernel.
THCudaCheck(Gemm::launch(params));
// Update batched GEMM params based on completed work
batchesLeft = batchesLeft - iterBatchCount;
a += iterBatchCount * strideA;
b += iterBatchCount * strideB;
c += iterBatchCount * strideC;
iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
} while(batchesLeft > 0);
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
...
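The do-while loop above exists because `batchCount` maps to `gridDim.z`, which CUDA caps at 65535, so the kernel is re-launched on successive slices of the batch while the A/B/C pointers advance by `iterBatchCount * stride`. A minimal PyTorch-level sketch of the same chunking idea (the helper name is hypothetical and not part of this commit):

```python
import torch

def chunked_bmm(a, b, max_batch=(1 << 16) - 1):
    """Batched matmul performed in slices of at most `max_batch` batches,
    mirroring the gridDim.z <= 64K fragmentation used by the CUTLASS launcher above."""
    out = torch.empty(a.size(0), a.size(1), b.size(2), dtype=a.dtype, device=a.device)
    for start in range(0, a.size(0), max_batch):
        end = min(start + max_batch, a.size(0))
        # each slice is an independent batched GEMM over <= 65535 batches
        torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out
```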
@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class EncdecMultiheadAttn(nn.Module):
"""Multi-headed attention.
@@ -46,12 +47,12 @@ class EncdecMultiheadAttn(nn.Module):
self.in_proj_bias_q = None
self.in_proj_bias_kv = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
@@ -59,11 +60,11 @@ class EncdecMultiheadAttn(nn.Module):
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_encdec_attn_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
@@ -76,14 +77,14 @@ class EncdecMultiheadAttn(nn.Module):
nn.init.constant_(self.in_proj_bias_q, 0.)
nn.init.constant_(self.in_proj_bias_kv, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add:
if self.impl == 'fast' :
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
@@ -93,37 +94,37 @@ class EncdecMultiheadAttn(nn.Module):
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
return outputs,None
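For orientation, a hedged usage sketch of the module above on the pure-PyTorch path (`impl='default'`). The constructor arguments are assumptions inferred from the attributes referenced in `forward()` (embed_dim, num_heads, dropout, impl); the forward signature and the `(outputs, None)` return value come directly from the code above.

```python
import torch
from apex.contrib.multihead_attn import EncdecMultiheadAttn  # assumed import path

# Assumed constructor: EncdecMultiheadAttn(embed_dim, num_heads, dropout=..., impl=...)
attn = EncdecMultiheadAttn(1024, 16, dropout=0.1, impl='default').cuda().half()

q  = torch.randn(32, 8, 1024, dtype=torch.float16, device='cuda')  # [time, batch, channel]
kv = torch.randn(48, 8, 1024, dtype=torch.float16, device='cuda')

# key_padding_mask and attn_mask are mutually exclusive, as asserted in forward().
out, _ = attn(q, kv, kv, key_padding_mask=None, attn_mask=None, is_training=True)
```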
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class EncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
if use_biases_t[0]:
input_lin_q_results = torch.addmm(input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
@@ -35,15 +33,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0]:
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
@@ -52,7 +50,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
@@ -64,15 +62,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
@@ -83,12 +81,12 @@ class EncdecAttnFunc(torch.autograd.Function) :
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
# Given that pytorch cannot currently perform autograd with an output tensor specified,
@@ -100,18 +98,18 @@ class EncdecAttnFunc(torch.autograd.Function) :
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else:
outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
@@ -130,11 +128,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
@@ -150,9 +148,9 @@ class EncdecAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
@@ -161,7 +159,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared before hand to properly slice out query, key, and value grads.
@@ -169,7 +167,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:,:,0,:]
values_grads = input_lin_kv_results_grads[:,:,1,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
@@ -182,13 +180,13 @@ class EncdecAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
@@ -215,14 +213,14 @@ class EncdecAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Q Linear GEMM - DGRAD
@@ -253,11 +251,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
if use_biases_t[0]:
input_bias_grads_q = torch.sum(queries_grads, 0)
input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
else:
input_bias_grads_q = None
input_bias_grads_kv = None
...
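The comments above describe the layout trick the encoder/decoder attention relies on: one GEMM projects the query, a single fused GEMM projects K and V together, and keys/values are then carved out of that fused result as views only. A minimal sketch of that layout in plain PyTorch (function and variable names are illustrative, not the commit's API):

```python
import torch
import torch.nn.functional as F

def project_q_and_fused_kv(query, key, wq, wkv, heads):
    """Separate Q projection plus fused KV projection, followed by a
    metadata-only split into K and V, mirroring the layout in the comments above.
    Assumed shapes: wq is [embed, embed], wkv is [2*embed, embed]."""
    seql_q, bsz, embed = query.shape
    seql_k = key.size(0)
    head_dim = embed // heads

    q = F.linear(query, wq)                          # [seql_q, bsz, embed]
    kv = F.linear(key, wkv)                          # [seql_k, bsz, 2*embed]

    kv = kv.view(seql_k, bsz * heads, 2, head_dim)   # [seql_k, bsz*heads, 2, head_dim]
    k, v = kv[:, :, 0, :], kv[:, :, 1, :]            # views, no copies
    q = q.view(seql_q, bsz * heads, head_dim)
    return q, k, v
```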
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
@@ -51,7 +47,7 @@ class FastEncdecAttnFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
...
@@ -6,18 +6,14 @@
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn_norm_add
class FastEncdecAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
@@ -70,7 +66,7 @@ class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
...
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
@@ -45,7 +41,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
...
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn_norm_add
class FastSelfAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
@@ -57,7 +53,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
...
@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class SelfMultiheadAttn(nn.Module):
"""Multi-headed attention.
@@ -43,12 +44,12 @@ class SelfMultiheadAttn(nn.Module):
self.register_parameter('out_proj_bias', None)
self.in_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
@@ -56,11 +57,11 @@ class SelfMultiheadAttn(nn.Module):
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_self_attn_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
@@ -71,14 +72,14 @@ class SelfMultiheadAttn(nn.Module):
if self.bias:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add:
if self.impl == 'fast':
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
@@ -87,36 +88,36 @@ class SelfMultiheadAttn(nn.Module):
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
else:
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
...
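Both modules share the same include_norm_add composition on the default path: LayerNorm the input, run attention on the normalized tensor, apply dropout, then add the original input back (the role of jit_dropout_add above). A hedged sketch of that composition, with `attn_fn` standing in for whichever attention callable the module selected:

```python
import torch
import torch.nn.functional as F

def norm_add_block(x, attn_fn, layer_norm, dropout_prob, is_training):
    """Pre-LayerNorm residual block as composed above (default impl path).
    `attn_fn` is assumed to be any callable mapping a tensor to a tensor of
    the same shape; it stands in for the selected attention function."""
    y = attn_fn(layer_norm(x))                              # attention on the normalized input
    y = F.dropout(y, p=dropout_prob, training=is_training)  # matches jit_dropout_add when training
    return x + y                                            # residual add of the un-normalized input
```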
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class SelfAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0]:
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0,1),
beta=1., alpha=1.)
else:
input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
@@ -39,10 +36,10 @@ class SelfAttnFunc(torch.autograd.Function) :
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
@@ -51,15 +48,15 @@ class SelfAttnFunc(torch.autograd.Function) :
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
@@ -70,12 +67,12 @@ class SelfAttnFunc(torch.autograd.Function) :
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
# Given that pytorch cannot currently perform autograd with an output tensor specified,
@@ -87,18 +84,18 @@ class SelfAttnFunc(torch.autograd.Function) :
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else:
outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
@@ -114,11 +111,11 @@ class SelfAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
@@ -131,9 +128,9 @@ class SelfAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
@@ -142,7 +139,7 @@ class SelfAttnFunc(torch.autograd.Function) :
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared before hand to properly slice out query, key, and value grads.
@@ -150,7 +147,7 @@ class SelfAttnFunc(torch.autograd.Function) :
queries_grads = input_lin_results_grads[:,:,0,:]
keys_grads = input_lin_results_grads[:,:,1,:]
values_grads = input_lin_results_grads[:,:,2,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
@@ -163,13 +160,13 @@ class SelfAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else:
output_bias_grads = None
# Matmul2 - DGRAD1
@@ -196,14 +193,14 @@ class SelfAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Linear GEMM - DGRAD
@@ -221,9 +218,9 @@ class SelfAttnFunc(torch.autograd.Function) :
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
if use_biases_t[0]:
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else:
input_bias_grads = None
return None, None, None, None, \
...
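Two details in the function above are worth calling out: Q, K and V are carved out of one fused input-linear output as views, and the 1/sqrt(head_dim) scale is folded into the batched GEMM through baddbmm's alpha rather than applied as a separate elementwise multiply. A minimal sketch of just that step (names are illustrative, not the commit's API):

```python
import torch

def scaled_qk_scores(input_lin_results, heads, scale):
    """Slice Q and K out of the fused QKV projection output (views only) and
    compute scale * Q @ K^T with the scale applied via baddbmm's alpha,
    mirroring the matmul1 step above. input_lin_results: [seql, bsz, 3*embed]."""
    seql, bsz, _ = input_lin_results.shape
    head_dim = input_lin_results.size(2) // (heads * 3)
    qkv = input_lin_results.view(seql, bsz * heads, 3, head_dim)
    q, k = qkv[:, :, 0, :], qkv[:, :, 1, :]          # metadata-only slices
    scores = torch.empty(bsz * heads, seql, seql, dtype=q.dtype, device=q.device)
    return torch.baddbmm(scores, q.transpose(0, 1), k.transpose(0, 1).transpose(1, 2),
                         out=scores, beta=0.0, alpha=scale)
```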
@@ -205,7 +205,6 @@ if "--fast_multihead_attn" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
import subprocess
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
@@ -260,6 +259,70 @@ if "--fast_multihead_attn" in sys.argv:
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
setup(
name='apex',
version='0.1',
...
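As a hedged sanity check (the exact pip invocation below is an assumption, not taken from this diff), the four extension modules declared in setup.py above can be probed after an install that enables the contrib kernels:

```python
# Assumed install command (not part of this diff):
#   pip install -v --no-cache-dir --global-option="--fast_multihead_attn" .
import importlib.util

for mod in ("fast_self_multihead_attn",
            "fast_self_multihead_attn_norm_add",
            "fast_encdec_multihead_attn",
            "fast_encdec_multihead_attn_norm_add"):
    spec = importlib.util.find_spec(mod)
    print(mod, "OK" if spec is not None else "MISSING")
```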