Unverified Commit 1733946a authored by Kevin Stephano, committed by GitHub

Change Multihead Attention to allow batched GEMMs with batch counts of 64K and greater. (#728)

* Adding C++ Multihead Attention implementation to contrib.

* Add reference test that at least works for forward.

* Remove CublasLt support from multihead attention.

* Add new Python version of self attention.

* Update python model of MHA with backward pass.

* Fixed Output Linear connection in MHA.

* Clean up compiles and add documentation to PySelfAttention.

* Add Encdec Python version of multihead attention. Clean up files.

* Tests for self and encdec multihead attention.

* Add reference pytorch implementation of attention with norm and add.

* Add cutlass branch definition.

* Add cutlass download to compile.

* Add norm/add tests.

* Add biases to pytorch python versions.

* Add tests and fix issues with python version of attention masking.

* Create README.md

* Update README.md

* Update README.md

* Update perf test parameters.

* Update README.md

* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update README.md

* Update README.md

* Fix matmul1 output tensor size. Fix tests that missed the issue.

* Allow for Z dimensions of 64K and greater on batched GEMMs.

* Remove redundant imports.

* General cleanup; remove deprecated or unused functions.
......@@ -100,9 +100,48 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the CUTLASS GEMM kernel.
THCudaCheck(Gemm::launch(params));
// batchCount in cutlass batched GEMM kernels maps to gridDim.z, which is limited to 16 bits.
// To implement batched GEMM with larger batch size, we fragment it into
// smaller batched GEMMs with gridDim.z of at most 65535.
long batchesLeft = batchCount;
long iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
do {
//printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f TotalBatches: %ld iterBatchCount %ld\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta, batchesLeft, iterBatchCount);
int result = params.initialize(
m, // M dimension for each batch
n, // N dimension for each batch
k, // K dimension for each batch
alpha, // scalar alpha
a,
lda,
strideA, // distance in memory between the first element of neighboring batch
b,
ldb,
strideB, // distance in memory between the first element of neighboring batch
beta, // scalar beta
c, // source matrix C
ldc,
strideC, // distance in memory between the first element of neighboring batch
c, // destination matrix C (may be different memory than source C matrix)
ldc,
strideC, // distance in memory between the first element of neighboring batch
iterBatchCount
);
AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
// Launch the CUTLASS GEMM kernel.
THCudaCheck(Gemm::launch(params));
// Update batched GEMM params based on completed work
batchesLeft = batchesLeft - iterBatchCount;
a += iterBatchCount * strideA;
b += iterBatchCount * strideB;
c += iterBatchCount * strideC;
iterBatchCount = std::min(batchesLeft, static_cast<long>((1 << 16) - 1));
} while(batchesLeft > 0);
}
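For intuition, here is a minimal PyTorch sketch of the same chunking idea applied to a plain batched matmul: the batch dimension is processed in slices of at most 65535 so that each launch stays within the 16-bit gridDim.z limit described above. The helper name chunked_bmm and the shapes are illustrative assumptions, not part of this commit.

import torch

MAX_GRID_Z = (1 << 16) - 1  # 65535, the largest batch count a single launch can carry in gridDim.z

def chunked_bmm(a, b, chunk=MAX_GRID_Z):
    # a: [batch, m, k], b: [batch, k, n] -> out: [batch, m, n]
    out = torch.empty((a.size(0), a.size(1), b.size(2)), dtype=a.dtype, device=a.device)
    for start in range(0, a.size(0), chunk):
        end = min(start + chunk, a.size(0))
        # Each slice is a batched GEMM whose batch count fits in 16 bits.
        torch.bmm(a[start:end], b[start:end], out=out[start:end])
    return out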
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
......
......@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training) :
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class EncdecMultiheadAttn(nn.Module):
"""Multi-headed attention.
......@@ -46,12 +47,12 @@ class EncdecMultiheadAttn(nn.Module):
self.in_proj_bias_q = None
self.in_proj_bias_kv = None
self.out_proj_bias = None
if self.include_norm_add :
if impl == 'fast' :
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else :
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
......@@ -59,11 +60,11 @@ class EncdecMultiheadAttn(nn.Module):
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add :
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else :
else:
if impl == 'fast' : self.attn_func = fast_encdec_attn_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
......@@ -76,14 +77,14 @@ class EncdecMultiheadAttn(nn.Module):
nn.init.constant_(self.in_proj_bias_q, 0.)
nn.init.constant_(self.in_proj_bias_kv, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add :
if self.include_norm_add:
if self.impl == 'fast' :
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else :
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True) :
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
......@@ -93,37 +94,37 @@ class EncdecMultiheadAttn(nn.Module):
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None :
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None :
elif attn_mask is not None:
mask = attn_mask
else :
else:
mask = None
if self.include_norm_add :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else :
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
if is_training :
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else :
outputs = outputs + query
else :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
else :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
return outputs,None
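A hedged usage sketch of the module above. The constructor arguments (embed_dim, num_heads, dropout, impl) and the import path apex.contrib.multihead_attn are assumptions inferred from the surrounding code; inputs follow the documented Time x Batch x Channel layout, and forward returns (outputs, None).

import torch
from apex.contrib.multihead_attn import EncdecMultiheadAttn  # assumed import path

seq_len, batch, embed_dim, heads = 64, 8, 1024, 16
attn = EncdecMultiheadAttn(embed_dim, heads, dropout=0.1, impl='default').cuda().half()

q  = torch.randn(seq_len, batch, embed_dim, device='cuda', dtype=torch.float16)
kv = torch.randn(seq_len, batch, embed_dim, device='cuda', dtype=torch.float16)

# query comes from the decoder, key/value from the encoder.
outputs, _ = attn(q, kv, kv, key_padding_mask=None, attn_mask=None, is_training=True)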
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class EncdecAttnFunc(torch.autograd.Function) :
class EncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob) :
use_biases_t = Variable(torch.tensor([input_biases_q is not None]))
heads_t = Variable(torch.tensor([heads]))
scale_t = Variable(torch.tensor([scale]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
if use_biases_t[0] :
if use_biases_t[0]:
input_lin_q_results = torch.addmm(input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0,1),
beta=1., alpha=1.)
else :
else:
input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
......@@ -35,15 +33,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)] (transpose [0,1])
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0] :
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
if use_biases_t[0]:
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0,1),
beta=1., alpha=1.)
else :
else:
input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
......@@ -52,7 +50,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
......@@ -64,15 +62,15 @@ class EncdecAttnFunc(torch.autograd.Function) :
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None :
if mask is not None:
# Self Attention Time Mask
if use_time_mask :
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else :
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
......@@ -83,12 +81,12 @@ class EncdecAttnFunc(torch.autograd.Function) :
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training :
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else :
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
# Given that pytorch cannot currently perform autograd with an output tensor specified,
......@@ -100,18 +98,18 @@ class EncdecAttnFunc(torch.autograd.Function) :
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0] :
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else :
else:
outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
......@@ -130,11 +128,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
......@@ -150,9 +148,9 @@ class EncdecAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear output (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
......@@ -161,7 +159,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand to properly slice out query, key, and value grads.
......@@ -169,7 +167,7 @@ class EncdecAttnFunc(torch.autograd.Function) :
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:,:,0,:]
values_grads = input_lin_kv_results_grads[:,:,1,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
......@@ -182,13 +180,13 @@ class EncdecAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0] :
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else :
else:
output_bias_grads = None
# Matmul2 - DGRAD1
......@@ -215,14 +213,14 @@ class EncdecAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Q Linear GEMM - DGRAD
......@@ -253,11 +251,11 @@ class EncdecAttnFunc(torch.autograd.Function) :
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
if use_biases_t[0] :
if use_biases_t[0]:
input_bias_grads_q = torch.sum(queries_grads, 0)
input_bias_grads_kv = torch.sum(input_lin_kv_results_grads, 0)
else :
else:
input_bias_grads_q = None
input_bias_grads_kv = None
......
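To make the calling convention of the default implementation concrete, here is a minimal sketch of driving EncdecAttnFunc directly through Function.apply, using the argument order visible in forward() above. All shapes, dtypes, and the scale value are illustrative assumptions; in practice the EncdecMultiheadAttn module assembles these arguments.

import torch

seql, batch, heads, embed_dim = 32, 4, 16, 1024
head_dim = embed_dim // heads

inputs_q  = torch.randn(seql, batch, embed_dim, device='cuda', requires_grad=True)
inputs_kv = torch.randn(seql, batch, embed_dim, device='cuda', requires_grad=True)
w_q   = torch.randn(embed_dim, embed_dim, device='cuda', requires_grad=True)
w_kv  = torch.randn(2 * embed_dim, embed_dim, device='cuda', requires_grad=True)
w_out = torch.randn(embed_dim, embed_dim, device='cuda', requires_grad=True)

# (use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
#  input_weights_q, input_weights_kv, output_weights,
#  input_biases_q, input_biases_kv, output_biases, mask, dropout_prob)
out = EncdecAttnFunc.apply(False, True, heads, head_dim ** -0.5,
                           inputs_q, inputs_kv, w_q, w_kv, w_out,
                           None, None, None, None, 0.1)
out.sum().backward()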
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function) :
class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
......@@ -51,7 +47,7 @@ class FastEncdecAttnFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
......
......@@ -6,18 +6,14 @@
# can be found in the PATENTS file in the same directory.
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_encdec_multihead_attn_norm_add
class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
class FastEncdecAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
......@@ -70,7 +66,7 @@ class FastEncdecAttnNormAddFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
......
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function) :
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
......@@ -45,7 +41,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
......
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
import fast_self_multihead_attn_norm_add
class FastSelfAttnNormAddFunc(torch.autograd.Function) :
class FastSelfAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob) :
heads_t = Variable(torch.tensor([heads]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
......@@ -57,7 +53,7 @@ class FastSelfAttnNormAddFunc(torch.autograd.Function) :
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
......
......@@ -2,19 +2,20 @@ import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training) :
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
out = F.dropout(x, p=prob, training=True)
out = residual + out
return out
class SelfMultiheadAttn(nn.Module):
"""Multi-headed attention.
......@@ -43,12 +44,12 @@ class SelfMultiheadAttn(nn.Module):
self.register_parameter('out_proj_bias', None)
self.in_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add :
if impl == 'fast' :
if self.include_norm_add:
if impl == 'fast':
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else :
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.lyr_nrm_gamma_weights = None
......@@ -56,11 +57,11 @@ class SelfMultiheadAttn(nn.Module):
self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add :
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
else :
else:
if impl == 'fast' : self.attn_func = fast_self_attn_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
......@@ -71,14 +72,14 @@ class SelfMultiheadAttn(nn.Module):
if self.bias:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
if self.include_norm_add :
if self.impl == 'fast' :
if self.include_norm_add:
if self.impl == 'fast':
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else :
else:
self.lyr_nrm.reset_parameters()
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True) :
def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
......@@ -87,36 +88,36 @@ class SelfMultiheadAttn(nn.Module):
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
if key_padding_mask is not None :
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None :
elif attn_mask is not None:
mask = attn_mask
else :
else:
mask = None
if self.include_norm_add :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else :
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
if is_training :
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else :
outputs = outputs + query
else :
if self.impl == 'fast' :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.in_proj_weight, self.out_proj_weight, mask, self.dropout)
else :
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
self.in_proj_weight, self.out_proj_weight,
self.in_proj_bias, self.out_proj_bias,
mask, self.dropout)
......
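The docstring above describes two mutually exclusive mask inputs. A small sketch of how each might be constructed follows; the sizes and the commented SelfMultiheadAttn calls are illustrative assumptions.

import torch

seql, batch = 32, 4

# key_padding_mask: [batch, src_len], 1s mark padded key positions.
key_padding_mask = torch.zeros(batch, seql, dtype=torch.uint8, device='cuda')
key_padding_mask[:, -5:] = 1

# attn_mask (time mask): square [seql, seql], 1s above the diagonal block future positions.
attn_mask = torch.triu(torch.ones(seql, seql, dtype=torch.uint8, device='cuda'), diagonal=1)

# Pass exactly one of the two to forward(), e.g.:
# outputs, _ = attn(x, x, x, key_padding_mask=key_padding_mask)
# outputs, _ = attn(x, x, x, attn_mask=attn_mask)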
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.autograd.variable import Variable
class SelfAttnFunc(torch.autograd.Function) :
class SelfAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob) :
use_biases_t = Variable(torch.tensor([input_biases is not None]))
heads_t = Variable(torch.tensor([heads]))
scale_t = Variable(torch.tensor([scale]))
dropout_prob_t = Variable(torch.tensor([dropout_prob]))
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)] (transpose [0,1])
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0] :
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
if use_biases_t[0]:
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0,1),
beta=1., alpha=1.)
else :
else:
input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
......@@ -39,10 +36,10 @@ class SelfAttnFunc(torch.autograd.Function) :
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
......@@ -51,15 +48,15 @@ class SelfAttnFunc(torch.autograd.Function) :
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
if mask is not None :
if mask is not None:
# Self Attention Time Mask
if use_time_mask :
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
# Key Padding Mask
else :
else:
batches,seql_q,seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
......@@ -70,12 +67,12 @@ class SelfAttnFunc(torch.autograd.Function) :
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training :
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
else :
else:
dropout_results = softmax_results
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
# Given that pytorch cannot currently perform autograd with an output tensor specified,
......@@ -87,18 +84,18 @@ class SelfAttnFunc(torch.autograd.Function) :
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ] transpose(0,1)
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0] :
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
else :
else:
outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
......@@ -114,11 +111,11 @@ class SelfAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads) :
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
......@@ -131,9 +128,9 @@ class SelfAttnFunc(torch.autograd.Function) :
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear output (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
......@@ -142,7 +139,7 @@ class SelfAttnFunc(torch.autograd.Function) :
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared beforehand to properly slice out query, key, and value grads.
......@@ -150,7 +147,7 @@ class SelfAttnFunc(torch.autograd.Function) :
queries_grads = input_lin_results_grads[:,:,0,:]
keys_grads = input_lin_results_grads[:,:,1,:]
values_grads = input_lin_results_grads[:,:,2,:]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
......@@ -163,13 +160,13 @@ class SelfAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
if use_biases_t[0] :
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
else :
else:
output_bias_grads = None
# Matmul2 - DGRAD1
......@@ -196,14 +193,14 @@ class SelfAttnFunc(torch.autograd.Function) :
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
# Input Linear GEMM - DGRAD
......@@ -221,9 +218,9 @@ class SelfAttnFunc(torch.autograd.Function) :
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
if use_biases_t[0] :
if use_biases_t[0]:
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else :
else:
input_bias_grads = None
return None, None, None, None, \
......
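A standalone sketch of the slicing and batching convention the comments above describe: the fused input-linear output is viewed as [seql, seqs*heads, 3, head_dim] so that queries, keys, and values are strided views (no copies), and seqs*heads becomes the batch dimension of the batched GEMM. The shapes are illustrative assumptions.

import torch

seql, seqs, heads, head_dim = 32, 4, 16, 64
embed_dim = heads * head_dim
input_lin_results = torch.randn(seql, seqs, 3 * embed_dim, device='cuda', dtype=torch.float16)

input_lin_results = input_lin_results.view(seql, seqs * heads, 3, head_dim)
queries = input_lin_results[:, :, 0, :]   # [seql, seqs*heads, head_dim]
keys    = input_lin_results[:, :, 1, :]
values  = input_lin_results[:, :, 2, :]

# Matmul1: per-batch (seql_q x head_dim) x (head_dim x seql_k) = (seql_q x seql_k)
scale = head_dim ** -0.5
scores = torch.bmm(queries.transpose(0, 1), keys.transpose(0, 1).transpose(1, 2)) * scale
# scores: [seqs*heads, seql_q, seql_k]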
......@@ -205,7 +205,6 @@ if "--fast_multihead_attn" in sys.argv:
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
import subprocess
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
......@@ -260,6 +259,70 @@ if "--fast_multihead_attn" in sys.argv:
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
if "--fast_multihead_attn" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension
sys.argv.remove("--fast_multihead_attn")
from torch.utils.cpp_extension import BuildExtension
cmdclass['build_ext'] = BuildExtension
if torch.utils.cpp_extension.CUDA_HOME is None:
raise RuntimeError("--fast_multihead_attn was requested, but nvcc was not found. Are you sure your environment has nvcc available? If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, only images whose names contain 'devel' will provide nvcc.")
else:
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-I./apex/contrib/csrc/multihead_attn/cutlass/',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros}))
setup(
name='apex',
version='0.1',
......
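A hedged post-install check, assuming apex was built with the --fast_multihead_attn option: each of the four CUDA extensions declared above should then be importable.

for mod in ('fast_self_multihead_attn',
            'fast_self_multihead_attn_norm_add',
            'fast_encdec_multihead_attn',
            'fast_encdec_multihead_attn_norm_add'):
    try:
        __import__(mod)
        print(mod, 'OK')
    except ImportError as err:
        print(mod, 'missing:', err)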