"vscode:/vscode.git/clone" did not exist on "cbb901ac51bd6c41e4243ffb936ef0e2f7ca8ada"
Unverified Commit 1733946a authored by Kevin Stephano, committed by GitHub

Change to Multihead Attention to allow Batched GEMMs larger than 64K. (#728)

* Adding C++ Multihead Attention implementation to contrib.

* Add reference test that at least works for forward.

* Remove CublasLt support from multihead attention.

* Add new Python version of self attention.

* Update python model of MHA with backward pass.

* Fixed Output Linear connection in MHA.

* Clean up compiles and add documentation to PySelfAttention.

* Add Encdec Python version of multihead attention. Clean up files.

* Tests for self and encdec multihead attention.

* Add reference pytorch implementation of attention with norm and add.

* Add cutlass branch definition.

* Add cutlass download to compile.

* Add norm/add tests.

* Add biases to pytorch python versions.

* Add tests and fix issues with python version of attention masking.

* Create README.md

* Update README.md

* Update README.md

* Update perf test parameters.

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Fix matmul1 output tensor size. Fix tests that missed the issue.

* Allow for Z dimensions of 64K and greater on batched GEMMs.

* remove redundant imports

* general cleanup, remove deprecated or unused functions
parent 50338df6
@@ -100,9 +100,48 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
  AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
-  // Launch the CUTLASS GEMM kernel.
-  THCudaCheck(Gemm::launch(params));
+  // batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits.
+  // To implement batched GEMM with larger batch size, we fragment it into
+  // smaller batched GEMMs of gridDim.z <= 64k
+  long batchesLeft = batchCount;
+  long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
+
+  do {
+    //printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
+    int result = params.initialize(
+      m,              // M dimension for each batch
+      n,              // N dimension for each batch
+      k,              // K dimension for each batch
+      alpha,          // scalar alpha
+      a,
+      lda,
+      strideA,        // distance in memory between the first element of neighboring batch
+      b,
+      ldb,
+      strideB,        // distance in memory between the first element of neighboring batch
+      beta,           // scalar beta
+      c,              // source matrix C
+      ldc,
+      strideC,        // distance in memory between the first element of neighboring batch
+      c,              // destination matrix C (may be different memory than source C matrix)
+      ldc,
+      strideC,        // distance in memory between the first element of neighboring batch
+      iterBatchCount
+    );
+    AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
+    // Launch the CUTLASS GEMM kernel.
+    THCudaCheck(Gemm::launch(params));
+    // Update batched GEMM params based on completed work
+    batchesLeft = batchesLeft - iterBatchCount;
+    a += iterBatchCount * strideA;
+    b += iterBatchCount * strideB;
+    c += iterBatchCount * strideC;;
+    iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
+  } while(batchesLeft > 0);
 }

void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
...
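The new do/while loop above works around CUDA's 16-bit gridDim.z limit by launching the batched GEMM in slices of at most 65535 batches. A minimal PyTorch sketch of the same fragmentation idea (the helper name, shapes, and chunk constant are illustrative, not part of this commit):

```python
import torch

GRID_Z_MAX = (1 << 16) - 1  # 65535, the largest batch count one launch can map to gridDim.z

def chunked_bmm(a, b, chunk_size=GRID_Z_MAX):
    """Run a batched matmul in slices of at most chunk_size batches,
    mirroring the do/while fragmentation loop in CutlassGemm_FP32Accum."""
    out = torch.empty(a.size(0), a.size(1), b.size(2), dtype=a.dtype, device=a.device)
    for start in range(0, a.size(0), chunk_size):
        end = min(start + chunk_size, a.size(0))
        torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out

# Example: a batch dimension larger than 64K, split into two launches.
a = torch.randn(70000, 8, 16)
b = torch.randn(70000, 16, 4)
assert torch.allclose(chunked_bmm(a, b), torch.bmm(a, b), atol=1e-5)
```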
@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
-from torch.autograd.variable import Variable

from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func

@torch.jit.script
-def jit_dropout_add(x, residual, prob, is_training) :
+def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    out = F.dropout(x, p=prob, training=True)
    out = residual + out
    return out

class EncdecMultiheadAttn(nn.Module):
    """Multi-headed attention.
@@ -46,12 +47,12 @@ class EncdecMultiheadAttn(nn.Module):
            self.in_proj_bias_q = None
            self.in_proj_bias_kv = None
            self.out_proj_bias = None
-        if self.include_norm_add :
+        if self.include_norm_add:
-            if impl == 'fast' :
+            if impl == 'fast':
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
-            else :
+            else:
                self.register_parameter('lyr_norm_gamma_weights', None)
                self.register_parameter('lyr_norm_beta_weights', None)
                self.lyr_nrm_gamma_weights = None
@@ -59,11 +60,11 @@ class EncdecMultiheadAttn(nn.Module):
                self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
        self.reset_parameters()
-        if self.include_norm_add :
+        if self.include_norm_add:
            if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
            elif impl == 'default' : self.attn_func = encdec_attn_func
            else : assert False, "Unsupported impl: {} !".format(impl)
-        else :
+        else:
            if impl == 'fast' : self.attn_func = fast_encdec_attn_func
            elif impl == 'default' : self.attn_func = encdec_attn_func
            else : assert False, "Unsupported impl: {} !".format(impl)
@@ -76,14 +77,14 @@ class EncdecMultiheadAttn(nn.Module):
            nn.init.constant_(self.in_proj_bias_q, 0.)
            nn.init.constant_(self.in_proj_bias_kv, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
-        if self.include_norm_add :
+        if self.include_norm_add:
            if self.impl == 'fast' :
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
-            else :
+            else:
                self.lyr_nrm.reset_parameters()

-    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True) :
+    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
@@ -93,37 +94,37 @@ class EncdecMultiheadAttn(nn.Module):
        batch x src_len, where padding elements are indicated by 1s.
        """
-        if key_padding_mask is not None :
+        if key_padding_mask is not None:
            assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
-        elif attn_mask is not None :
+        elif attn_mask is not None:
            mask = attn_mask
-        else :
+        else:
            mask = None

-        if self.include_norm_add :
+        if self.include_norm_add:
-            if self.impl == 'fast' :
+            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
-            else :
+            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
                                         mask, self.dropout)
-                if is_training :
+                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
-                else :
+                else:
                    outputs = outputs + query
-        else :
+        else:
-            if self.impl == 'fast' :
+            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
-            else :
+            else:
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
                                         self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
                                         self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
                                         mask, self.dropout)
        return outputs,None
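For reference, a hedged usage sketch of the module diffed above. The import path and constructor arguments are assumptions inferred from this file rather than anything the commit documents, and the 'fast' implementation is assumed to expect fp16 CUDA tensors in (time, batch, channel) layout:

```python
import torch
# Assumed import path for the contrib module; adjust to the installed layout.
from apex.contrib.multihead_attn import EncdecMultiheadAttn

seql, batch, embed_dim, heads = 64, 16, 1024, 16
attn = EncdecMultiheadAttn(embed_dim, heads, dropout=0.1, impl='fast').cuda().half()

q = torch.randn(seql, batch, embed_dim, device='cuda', dtype=torch.half)   # decoder queries
kv = torch.randn(seql, batch, embed_dim, device='cuda', dtype=torch.half)  # encoder memory
outputs, _ = attn(q, kv, kv, key_padding_mask=None, attn_mask=None, is_training=True)
print(outputs.shape)  # torch.Size([64, 16, 1024])
```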
import torch
-from torch import nn
-from torch.nn import Parameter
import torch.nn.functional as F
-from torch.autograd.variable import Variable

-class EncdecAttnFunc(torch.autograd.Function) :
+class EncdecAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
                input_weights_q, input_weights_kv, output_weights,
                input_biases_q, input_biases_kv, output_biases,
-                mask, dropout_prob) :
+                mask, dropout_prob):
-        use_biases_t = Variable(torch.tensor([input_biases_q is not None]))
+        use_biases_t = torch.tensor([input_biases_q is not None])
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        scale_t = Variable(torch.tensor([scale]))
+        scale_t = torch.tensor([scale])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        head_dim = inputs_q.size(2) // heads

        # Input Linear GEMM Q
        # input1: (activations) [seql_q, seqs, embed_dim(1024)]
        # input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
        # output: [seql_q, seqs, embed_dim]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            input_lin_q_results = torch.addmm(input_biases_q,
                                              inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                                              input_weights_q.transpose(0,1),
                                              beta=1., alpha=1.)
-        else :
+        else:
            input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
        input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))

        # Input Linear GEMM KV
@@ -35,15 +33,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
        # input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
        # output: [seql_k, seqs, embed_dim*2]
        # GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            input_lin_kv_results = torch.addmm(input_biases_kv,
                                               inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
                                               input_weights_kv.transpose(0,1),
                                               beta=1., alpha=1.)
-        else :
+        else:
            input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))

        # Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
@@ -52,7 +50,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
        keys = input_lin_kv_results[:,:,0,:]
        values = input_lin_kv_results[:,:,1,:]

        # Matmul1 Batched GEMMs
        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
@@ -64,15 +62,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
        matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
-        if mask is not None :
+        if mask is not None:
            # Self Attention Time Mask
-            if use_time_mask :
+            if use_time_mask:
                assert (len(mask.size()) == 2), "Timing mask is not 2D!"
                assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
                mask = mask.to(torch.bool)
                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
            # Key Padding Mask
-            else :
+            else:
                batches,seql_q,seql_k = matmul1_results.size()
                seqs = int(batches / heads)
                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
@@ -83,12 +81,12 @@ class EncdecAttnFunc(torch.autograd.Function) :
        softmax_results = F.softmax(matmul1_results, dim=-1)

        # Dropout - is not executed for inference
-        if is_training :
+        if is_training:
            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
-        else :
+        else:
            dropout_results = softmax_results
            dropout_mask = null_tensor

        # Matmul2 Batched GEMMs
        # The output tensor specification is needed here to specify the non-standard output.
        # Given that pytorch cannot currently perform autograd with an output tensor specified,
@@ -100,18 +98,18 @@ class EncdecAttnFunc(torch.autograd.Function) :
        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))

        # Output Linear GEMM
        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
        # Output: [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            outputs = torch.addmm(output_biases,
                                  matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
                                  output_weights.transpose(0,1),
                                  beta=1., alpha=1.)
-        else :
+        else:
            outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
        outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
@@ -130,11 +128,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
                              output_weights, \
                              dropout_mask, \
                              dropout_prob_t)

        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        use_biases_t, \
        heads_t, \
        scale_t, \
@@ -150,9 +148,9 @@ class EncdecAttnFunc(torch.autograd.Function) :
        output_weights, \
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors

        head_dim = inputs_q.size(2) // heads_t[0]

        # Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
@@ -161,7 +159,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
        input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
        keys = input_lin_kv_results[:,:,0,:]
        values = input_lin_kv_results[:,:,1,:]

        # Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensor is declared before hand to properly slice out query, key, and value grads.
@@ -169,7 +167,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
        queries_grads = torch.empty_like(queries)
        keys_grads = input_lin_kv_results_grads[:,:,0,:]
        values_grads = input_lin_kv_results_grads[:,:,1,:]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights) [ embed_dim, embed_dim ]
@@ -182,13 +180,13 @@ class EncdecAttnFunc(torch.autograd.Function) :
        # Input2: (activations) [seql_q*seqs, embed_dim ]
        # Output: [ seql_q, seqs, embed_dim ]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
-        else :
+        else:
            output_bias_grads = None

        # Matmul2 - DGRAD1
@@ -215,14 +213,14 @@ class EncdecAttnFunc(torch.autograd.Function) :
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])

        # Matmul1 - DGRAD2
        # Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
                                   out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])

        # Input Q Linear GEMM - DGRAD
@@ -253,11 +251,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
        # output: [2*embed_dim, embed_dim]
        # GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
        input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            input_bias_grads_q = torch.sum(queries_grads, 0)
            input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
-        else :
+        else:
            input_bias_grads_q = None
            input_bias_grads_kv = None
...
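The Matmul1 comments in the function above note that baddbmm is used so the 1/sqrt(head_dim) scale rides along as the batched GEMM's alpha instead of a separate elementwise multiply. A small CPU sketch of that equivalence, with illustrative sizes not taken from the commit:

```python
import torch

seqs, heads, head_dim, seql_q, seql_k = 4, 16, 64, 32, 48
queries = torch.randn(seql_q, seqs * heads, head_dim)
keys = torch.randn(seql_k, seqs * heads, head_dim)
scale = head_dim ** -0.5

# Scale applied inside the batched GEMM via alpha (beta=0.0 ignores the uninitialized input).
scores = torch.empty(seqs * heads, seql_q, seql_k)
torch.baddbmm(scores, queries.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2),
              out=scores, beta=0.0, alpha=scale)

# Same result as a plain bmm followed by an elementwise multiply.
reference = torch.bmm(queries.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2)) * scale
assert torch.allclose(scores, reference, atol=1e-5)
```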
import torch
-from torch import nn
-from torch.nn import Parameter
-import torch.nn.functional as F
-from torch.autograd.variable import Variable
import fast_encdec_multihead_attn

-class FastEncdecAttnFunc(torch.autograd.Function) :
+class FastEncdecAttnFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
+    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
@@ -51,7 +47,7 @@ class FastEncdecAttnFunc(torch.autograd.Function) :
        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        heads_t, \
        matmul2_results, \
        dropout_results, \
...
@@ -6,18 +6,14 @@
# can be found in the PATENTS file in the same directory.
import torch
-from torch import nn
-from torch.nn import Parameter
-import torch.nn.functional as F
-from torch.autograd.variable import Variable
import fast_encdec_multihead_attn_norm_add

-class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
+class FastEncdecAttnNormAddFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
+    def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
@@ -70,7 +66,7 @@ class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        heads_t, \
        matmul2_results, \
        dropout_results, \
...
import torch
-from torch import nn
-from torch.nn import Parameter
-import torch.nn.functional as F
-from torch.autograd.variable import Variable
import fast_self_multihead_attn

class FastSelfAttnFunc(torch.autograd.Function) :
    @staticmethod
-    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob) :
+    def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob):
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
@@ -45,7 +41,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        heads_t, \
        matmul2_results, \
        dropout_results, \
...
import torch
-from torch import nn
-from torch.nn import Parameter
-import torch.nn.functional as F
-from torch.autograd.variable import Variable
import fast_self_multihead_attn_norm_add

-class FastSelfAttnNormAddFunc(torch.autograd.Function) :
+class FastSelfAttnNormAddFunc(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob) :
+    def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        use_mask = (pad_mask is not None)
@@ -57,7 +53,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function) :
        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        heads_t, \
        matmul2_results, \
        dropout_results, \
...
@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
-from torch.autograd.variable import Variable

from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func

@torch.jit.script
-def jit_dropout_add(x, residual, prob, is_training) :
+def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    out = F.dropout(x, p=prob, training=True)
    out = residual + out
    return out

class SelfMultiheadAttn(nn.Module):
    """Multi-headed attention.
@@ -43,12 +44,12 @@ class SelfMultiheadAttn(nn.Module):
            self.register_parameter('out_proj_bias', None)
            self.in_proj_bias = None
            self.out_proj_bias = None
-        if self.include_norm_add :
+        if self.include_norm_add:
-            if impl == 'fast' :
+            if impl == 'fast':
                self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
                self.lyr_nrm = None
-            else :
+            else:
                self.register_parameter('lyr_norm_gamma_weights', None)
                self.register_parameter('lyr_norm_beta_weights', None)
                self.lyr_nrm_gamma_weights = None
@@ -56,11 +57,11 @@ class SelfMultiheadAttn(nn.Module):
                self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
        self.reset_parameters()
-        if self.include_norm_add :
+        if self.include_norm_add:
            if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
            elif impl == 'default' : self.attn_func = self_attn_func
            else : assert False, "Unsupported impl: {} !".format(impl)
-        else :
+        else:
            if impl == 'fast' : self.attn_func = fast_self_attn_func
            elif impl == 'default' : self.attn_func = self_attn_func
            else : assert False, "Unsupported impl: {} !".format(impl)
@@ -71,14 +72,14 @@ class SelfMultiheadAttn(nn.Module):
        if self.bias:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj_bias, 0.)
-        if self.include_norm_add :
+        if self.include_norm_add:
-            if self.impl == 'fast' :
+            if self.impl == 'fast':
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
-            else :
+            else:
                self.lyr_nrm.reset_parameters()

-    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True) :
+    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
        """Input shape: Time x Batch x Channel
        Self-attention can be implemented by passing in the same arguments for
@@ -87,36 +88,36 @@ class SelfMultiheadAttn(nn.Module):
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
-        if key_padding_mask is not None :
+        if key_padding_mask is not None:
            assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
            mask = key_padding_mask
-        elif attn_mask is not None :
+        elif attn_mask is not None:
            mask = attn_mask
-        else :
+        else:
            mask = None

-        if self.include_norm_add :
+        if self.include_norm_add:
-            if self.impl == 'fast' :
+            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
                                         self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
-            else :
+            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
                                         self.in_proj_weight, self.out_proj_weight,
                                         self.in_proj_bias, self.out_proj_bias,
                                         mask, self.dropout)
-                if is_training :
+                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
-                else :
+                else:
                    outputs = outputs + query
-        else :
+        else:
-            if self.impl == 'fast' :
+            if self.impl == 'fast':
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
                                         self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
-            else :
+            else:
                outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
                                         self.in_proj_weight, self.out_proj_weight,
                                         self.in_proj_bias, self.out_proj_bias,
                                         mask, self.dropout)
...
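As the docstring above says, padded source positions are flagged with a batch x src_len mask whose 1s mark padding. A hedged sketch of building such a key_padding_mask from per-sequence lengths (the lengths and sizes here are made up for illustration):

```python
import torch

# Hypothetical lengths for a batch of 4 source sequences, padded to src_len=6.
batch, src_len = 4, 6
lengths = torch.tensor([6, 4, 5, 2])

# key_padding_mask: (batch x src_len); 1s (True) mark padded positions to be ignored.
key_padding_mask = torch.arange(src_len).unsqueeze(0) >= lengths.unsqueeze(1)
print(key_padding_mask.to(torch.uint8))
# tensor([[0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 1, 1, 1, 1]], dtype=torch.uint8)

# Passed to the module as forward(query, key, value, key_padding_mask=key_padding_mask);
# it is mutually exclusive with attn_mask, as asserted in forward() above.
```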
import torch
-from torch import nn
-from torch.nn import Parameter
import torch.nn.functional as F
-from torch.autograd.variable import Variable

-class SelfAttnFunc(torch.autograd.Function) :
+class SelfAttnFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
                input_weights, output_weights,
                input_biases, output_biases,
-                mask, dropout_prob) :
+                mask, dropout_prob):
-        use_biases_t = Variable(torch.tensor([input_biases is not None]))
+        use_biases_t = torch.tensor([input_biases is not None])
-        heads_t = Variable(torch.tensor([heads]))
+        heads_t = torch.tensor([heads])
-        scale_t = Variable(torch.tensor([scale]))
+        scale_t = torch.tensor([scale])
-        dropout_prob_t = Variable(torch.tensor([dropout_prob]))
+        dropout_prob_t = torch.tensor([dropout_prob])
        null_tensor = torch.tensor([])
        head_dim = inputs.size(2) // heads

        # Input Linear GEMM
        # input1: (activations) [seql_q, seqs, embed_dim(1024)]
        # input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
        # output: [seql_q, seqs, embed_dim*3]
        # GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            input_lin_results = torch.addmm(input_biases,
                                            inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
                                            input_weights.transpose(0,1),
                                            beta=1., alpha=1.)
-        else :
+        else:
            input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
        input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))

        # Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
@@ -39,10 +36,10 @@ class SelfAttnFunc(torch.autograd.Function) :
        queries = input_lin_results[:,:,0,:]
        keys = input_lin_results[:,:,1,:]
        values = input_lin_results[:,:,2,:]

        # Matmul1 Batched GEMMs
        # The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
        # baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
        # a separate elementwise operation.
        # Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1)
        # Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
@@ -51,15 +48,15 @@ class SelfAttnFunc(torch.autograd.Function) :
        matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
        matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
-        if mask is not None :
+        if mask is not None:
            # Self Attention Time Mask
-            if use_time_mask :
+            if use_time_mask:
                assert (len(mask.size()) == 2), "Timing mask is not 2D!"
                assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
                mask = mask.to(torch.bool)
                matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
            # Key Padding Mask
-            else :
+            else:
                batches,seql_q,seql_k = matmul1_results.size()
                seqs = int(batches / heads)
                matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
@@ -70,12 +67,12 @@ class SelfAttnFunc(torch.autograd.Function) :
        softmax_results = F.softmax(matmul1_results, dim=-1)

        # Dropout - is not executed for inference
-        if is_training :
+        if is_training:
            dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
-        else :
+        else:
            dropout_results = softmax_results
            dropout_mask = null_tensor

        # Matmul2 Batched GEMMs
        # The output tensor specification is needed here to specify the non-standard output.
        # Given that pytorch cannot currently perform autograd with an output tensor specified,
@@ -87,18 +84,18 @@ class SelfAttnFunc(torch.autograd.Function) :
        matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
        matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
        matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))

        # Output Linear GEMM
        # Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
        # Output: [ seql_q, seqs, embed_dim ]
        # GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            outputs = torch.addmm(output_biases,
                                  matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
                                  output_weights.transpose(0,1),
                                  beta=1., alpha=1.)
-        else :
+        else:
            outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
        outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
@@ -114,11 +111,11 @@ class SelfAttnFunc(torch.autograd.Function) :
                              output_weights, \
                              dropout_mask, \
                              dropout_prob_t)

        return outputs.detach()

    @staticmethod
-    def backward(ctx, output_grads) :
+    def backward(ctx, output_grads):
        use_biases_t, \
        heads_t, \
        scale_t, \
@@ -131,9 +128,9 @@ class SelfAttnFunc(torch.autograd.Function) :
        output_weights, \
        dropout_mask, \
        dropout_prob_t = ctx.saved_tensors

        head_dim = inputs.size(2) // heads_t[0]

        # Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
        # Sequences and heads are combined to make the batch of the Batched GEMM
        # input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
@@ -142,7 +139,7 @@ class SelfAttnFunc(torch.autograd.Function) :
        queries = input_lin_results[:,:,0,:]
        keys = input_lin_results[:,:,1,:]
        values = input_lin_results[:,:,2,:]

        # Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
        # The gradients are identical in size to the Input Linear outputs.
        # The tensor is declared before hand to properly slice out query, key, and value grads.
@@ -150,7 +147,7 @@ class SelfAttnFunc(torch.autograd.Function) :
        queries_grads = input_lin_results_grads[:,:,0,:]
        keys_grads = input_lin_results_grads[:,:,1,:]
        values_grads = input_lin_results_grads[:,:,2,:]

        # Output Linear GEMM - DGRAD
        # Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
        # Input2: (weights) [ embed_dim, embed_dim ]
@@ -163,13 +160,13 @@ class SelfAttnFunc(torch.autograd.Function) :
        # Input2: (activations) [seql_q*seqs, embed_dim ]
        # Output: [ seql_q, seqs, embed_dim ]
        # GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
        output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
                                       matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
        output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
-        else :
+        else:
            output_bias_grads = None

        # Matmul2 - DGRAD1
@@ -196,14 +193,14 @@ class SelfAttnFunc(torch.autograd.Function) :
        # Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
        # Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
        queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
                                      out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])

        # Matmul1 - DGRAD2
        # Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
        # Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
        # Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
        # GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
        keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
                                   out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])

        # Input Linear GEMM - DGRAD
@@ -221,9 +218,9 @@ class SelfAttnFunc(torch.autograd.Function) :
        # GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
        input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
-        if use_biases_t[0] :
+        if use_biases_t[0]:
            input_bias_grads = torch.sum(input_lin_results_grads, 0)
-        else :
+        else:
            input_bias_grads = None

        return None, None, None, None, \
...
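The "should only impact meta data, no copies" comments in the function above rely on view plus slicing returning aliases of the fused input-linear output rather than new tensors. A small sketch checking that, with illustrative sizes:

```python
import torch

seql, seqs, heads, head_dim = 32, 4, 16, 64
# Stand-in for the fused QKV input-linear output: [seql, seqs, 3*embed_dim]
input_lin_results = torch.randn(seql, seqs, 3 * heads * head_dim)

# Reshape so sequences and heads form the batch dim, then slice out q, k, v.
input_lin_results = input_lin_results.view(seql, seqs * heads, 3, head_dim)
queries = input_lin_results[:, :, 0, :]
keys = input_lin_results[:, :, 1, :]
values = input_lin_results[:, :, 2, :]

# All three are strided views over the same storage: metadata only, no copies.
assert queries.data_ptr() == input_lin_results.data_ptr()
assert keys.data_ptr() == input_lin_results.data_ptr() + head_dim * keys.element_size()
assert values.shape == (seql, seqs * heads, head_dim)
```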
...@@ -205,7 +205,6 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -205,7 +205,6 @@ if "--fast_multihead_attn" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None: if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.") raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else: else:
import subprocess
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"]) subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append( ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn', CUDAExtension(name='fast_self_multihead_attn',
...@@ -260,6 +259,70 @@ if "--fast_multihead_attn" in sys.argv: ...@@ -260,6 +259,70 @@ if "--fast_multihead_attn" in sys.argv:
'--expt-extended-lambda', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros})) '--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
setup(
    name='apex',
    version='0.1',
...