Unverified Commit 7ec8ed67 authored by Masaki Kozuki's avatar Masaki Kozuki Committed by GitHub
Browse files

Faster `--fast_multihead_attn` build (#1245)

* merge .so files

* odr

* fix build

* update import

* apply psf/black with max line length of 120

* update

* fix

* update

* build fixed again but undefined symbol again

* fix 2, still layer norm grad is undefined

* remove unused cpp files

* without layer_norm.cuh, import works

* import fast_multihead_attn works...

but why? Was the unnecessary `#include "layer_norm.cuh"` the culprit
causing the shared objects to be unable to link `HostApplyLayerNorm` and
`HostLayerNormGradient`?

* clean up layer norm
parent ed94d0bb
......@@ -11,10 +11,10 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "dropout.h"
#include "layer_norm.h"
#include "softmax.h"
#include "strided_batched_gemm.h"
#include "dropout.cuh"
#include "layer_norm.cuh"
#include "softmax.cuh"
#include "strided_batched_gemm.cuh"
namespace multihead_attn {
namespace self_norm_add {
......@@ -363,7 +363,7 @@ std::vector<torch::Tensor> bwd_cuda(
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const *>(output_grads.data_ptr()),
static_cast<const half *>(output_grads.data_ptr()),
static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
static_cast<int>(batches), // n1
......
#pragma once
#include "philox.h"
#include "philox.cuh"
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>
......@@ -15,6 +15,14 @@ namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst,
const __half *src) {
......@@ -43,10 +51,6 @@ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
*((half2 *)dst) = *((half2 *)src);
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
const uint8_t *src) {
......@@ -54,14 +58,13 @@ __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
*dst = value;
}
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
// Scalar (ELEMENTS_PER_LDG = 1) specialization: accumulates the single
// additive-mask value into *dst in place (half-precision add).
// NOTE(review): assumes device arithmetic on __half is available for the
// target architecture (SM53+) — confirm against the build's arch flags.
template <>
__device__ __inline__ void
apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask;
}
template <>
__device__ __inline__ void
apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
......@@ -70,7 +73,6 @@ apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
*(dst + 2) += *(additive_mask + 2);
*(dst + 3) += *(additive_mask + 3);
}
} // namespace
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward
......@@ -3132,3 +3134,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
}
return false;
}
} // namespace
#pragma once
#include <iostream>
#include <vector>
......@@ -14,6 +15,7 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't')
return CUBLAS_OP_T;
......@@ -47,6 +49,7 @@ void CublasStridedBatchedGemm(
CUDA_R_16F, (int)ldc, strideC, (int)batchCount, CUDA_R_32F, algo));
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
} // namespace
template <cutlass::MatrixLayout::Kind A_LAYOUT,
cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
......@@ -153,6 +156,7 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
} while (batchesLeft > 0);
}
namespace {
void gemm_switch_fp32accum(char transa, char transb, long m,
long n, long k, float alpha, const half *a, long lda,
long strideA, const half *b, long ldb, long strideB,
......@@ -632,3 +636,4 @@ void HgemmStridedBatched(char transa, char transb, long m,
b, ldb, strideB, beta, c, ldc, strideC, batchCount);
}
} // namespace
......@@ -9,7 +9,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h>
#include "philox.h"
#include "philox.cuh"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize.
......
......@@ -5,16 +5,17 @@ from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm
if hasattr(torch._C, '_jit_set_profiling_executor') :
if hasattr(torch._C, "_jit_set_profiling_executor"):
torch._C._jit_set_profiling_executor(False)
if hasattr(torch._C, '_jit_set_profiling_mode') :
if hasattr(torch._C, "_jit_set_profiling_mode"):
torch._C._jit_set_profiling_mode(False)
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
......@@ -28,7 +29,8 @@ class EncdecMultiheadAttn(nn.Module):
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_add=False, impl="fast"):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
......@@ -38,43 +40,49 @@ class EncdecMultiheadAttn(nn.Module):
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_kv = Parameter(torch.Tensor(2*embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_kv = Parameter(torch.Tensor(2 * embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_kv = Parameter(torch.Tensor(2*embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
self.register_parameter('in_proj_bias_q', None)
self.register_parameter('in_proj_bias_kv', None)
self.in_proj_bias_q = None
self.register_parameter("in_proj_bias_q", None)
self.register_parameter("in_proj_bias_kv", None)
self.in_proj_bias_q = None
self.in_proj_bias_kv = None
self.out_proj_bias = None
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
if impl == "fast":
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.register_parameter("lyr_norm_gamma_weights", None)
self.register_parameter("lyr_norm_beta_weights", None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
if impl == "fast":
self.attn_func = fast_encdec_attn_norm_add_func
elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_encdec_attn_func
elif impl == 'default' : self.attn_func = encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
if impl == "fast":
self.attn_func = fast_encdec_attn_func
elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q)
......@@ -85,11 +93,11 @@ class EncdecMultiheadAttn(nn.Module):
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.)
nn.init.constant_(self.in_proj_bias_kv, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
nn.init.constant_(self.in_proj_bias_q, 0.0)
nn.init.constant_(self.in_proj_bias_kv, 0.0)
nn.init.constant_(self.out_proj_bias, 0.0)
if self.include_norm_add:
if self.impl == 'fast' :
if self.impl == "fast":
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
......@@ -106,7 +114,7 @@ class EncdecMultiheadAttn(nn.Module):
"""
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None:
mask = attn_mask
......@@ -114,28 +122,73 @@ class EncdecMultiheadAttn(nn.Module):
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
key,
self.lyr_nrm_gamma_weights,
self.lyr_nrm_beta_weights,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
lyr_nrm_results,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout)
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias,
mask, self.dropout)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
return outputs,None
return outputs, None
......@@ -4,16 +4,29 @@ import torch.nn.functional as F
class EncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs_q, inputs_kv,
input_weights_q, input_weights_kv, output_weights,
input_biases_q, input_biases_kv, output_biases,
mask, dropout_prob):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
def forward(
ctx,
use_time_mask,
is_training,
heads,
scale,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
input_biases_q,
input_biases_kv,
output_biases,
mask,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases_q is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
null_tensor = torch.tensor([])
head_dim = inputs_q.size(2) // heads
# Input Linear GEMM Q
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
......@@ -21,12 +34,17 @@ class EncdecAttnFunc(torch.autograd.Function):
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
if use_biases_t[0]:
input_lin_q_results = torch.addmm(input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0,1),
beta=1., alpha=1.)
input_lin_q_results = torch.addmm(
input_biases_q,
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
input_weights_q.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_q_results = torch.mm(inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0,1))
input_lin_q_results = torch.mm(
inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), input_weights_q.transpose(0, 1)
)
input_lin_q_results = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1), input_weights_q.size(0))
# Input Linear GEMM KV
# input1: (activations) [seql_k, seqs, embed_dim(1024)]
......@@ -34,58 +52,73 @@ class EncdecAttnFunc(torch.autograd.Function):
# output: [seql_k, seqs, embed_dim*2]
# GEMM: ( (seql_k*seqs) x embed_dim ) x ( embed_dim x embed_dim*2 ) = (seql_k*seqs x embed_dim*2)
if use_biases_t[0]:
input_lin_kv_results = torch.addmm(input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0,1),
beta=1., alpha=1.)
input_lin_kv_results = torch.addmm(
input_biases_kv,
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_kv_results = torch.mm(inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)), input_weights_kv.transpose(0,1))
input_lin_kv_results = torch.mm(
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
input_weights_kv.transpose(0, 1),
)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1), input_weights_kv.size(0))
# Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads, head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads, 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads, head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads, 2, head_dim)
keys = input_lin_kv_results[:, :, 0, :]
values = input_lin_kv_results[:, :, 1, :]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# baddbmm is used to apply the scale parameter via the Batched GEMM's alpha parameter instead of
# a separate elementwise operation.
# Input1: (Queries) [seql_q, seqs*heads, head_dim] tranpose(0,1)
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
matmul1_results = torch.empty(
(queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda")
)
matmul1_results = torch.baddbmm(
matmul1_results,
queries.transpose(0, 1),
keys.transpose(0, 1).transpose(1, 2),
out=matmul1_results,
beta=0.0,
alpha=scale_t[0],
)
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
assert len(mask.size()) == 2, "Timing mask is not 2D!"
assert mask.size(0) == mask.size(1), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
matmul1_results = matmul1_results.masked_fill_(mask, float("-inf"))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
batches, seql_q, seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf"))
matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
......@@ -95,9 +128,15 @@ class EncdecAttnFunc(torch.autograd.Function):
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
matmul2_results = torch.empty(
(dropout_results.size(1), dropout_results.size(0), values.size(2)),
dtype=dropout_results.dtype,
device=torch.device("cuda"),
).transpose(1, 0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)
matmul2_results = (
matmul2_results.transpose(0, 1).contiguous().view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
)
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
......@@ -105,87 +144,105 @@ class EncdecAttnFunc(torch.autograd.Function):
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
outputs = torch.addmm(
output_biases,
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
outputs = torch.mm(matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)), output_weights.transpose(0,1))
outputs = torch.mm(
matmul2_results.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2)),
output_weights.transpose(0, 1),
)
outputs = outputs.view(inputs_q.size(0), inputs_q.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
ctx.save_for_backward(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
head_dim = inputs_q.size(2) // heads_t[0]
head_dim = inputs_q.size(2) // heads_t[0]
# Slice out k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_kv_results: [seql_k, seqs, heads(16), 2, head_dim(64)]
# input_lin_kv_results: [seql_k, batches=seqs*heads, 2, head_dim]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1)*heads_t[0], head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1)*heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:,:,0,:]
values = input_lin_kv_results[:,:,1,:]
queries = input_lin_q_results.view(inputs_q.size(0), inputs_q.size(1) * heads_t[0], head_dim)
input_lin_kv_results = input_lin_kv_results.view(inputs_kv.size(0), inputs_kv.size(1) * heads_t[0], 2, head_dim)
keys = input_lin_kv_results[:, :, 0, :]
values = input_lin_kv_results[:, :, 1, :]
# Slice out k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared before hand to properly slice out query, key, and value grads.
input_lin_kv_results_grads = torch.empty_like(input_lin_kv_results)
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:,:,0,:]
values_grads = input_lin_kv_results_grads[:,:,1,:]
queries_grads = torch.empty_like(queries)
keys_grads = input_lin_kv_results_grads[:, :, 0, :]
values_grads = input_lin_kv_results_grads[:, :, 1, :]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights
)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1)*heads_t[0], head_dim).transpose(0,1)
output_weight_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)),
)
output_lin_grads = output_lin_grads.view(
output_grads.size(0), output_grads.size(1) * heads_t[0], head_dim
).transpose(0, 1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
output_bias_grads = torch.sum(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0
)
else:
output_bias_grads = None
......@@ -194,63 +251,82 @@ class EncdecAttnFunc(torch.autograd.Function):
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
# Matmul2 - DGRAD2
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))
# Mask and Scaling for Dropout (not a publically documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
# Softmax Grad (not a publically documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
queries_grads = torch.baddbmm(
queries_grads.transpose(0, 1),
softmax_grads,
keys.transpose(0, 1),
out=queries_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
keys_grads = torch.baddbmm(
keys_grads.transpose(0, 1),
softmax_grads.transpose(1, 2),
queries.transpose(0, 1),
out=keys_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Input Q Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, embed_dim(1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)]
# input2: (weights) [embed_dim (1024), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
queries_grads = queries_grads.transpose(0,1).view(inputs_q.size(0)*inputs_q.size(1), heads_t[0]*head_dim)
queries_grads = queries_grads.transpose(0, 1).view(inputs_q.size(0) * inputs_q.size(1), heads_t[0] * head_dim)
input_q_grads = torch.mm(queries_grads, input_weights_q)
input_q_grads = input_q_grads.view(inputs_q.size(0), inputs_q.size(1), inputs_q.size(2))
# Input KV Linear GEMM - DGRAD
# input1: (data grads) [seql_k, seqs, 2*embed_dim(2048)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)]
# input2: (weights) [embed_dim*2 (2048), embed_dim (1024)]
# output: [seql_k, seqs, embed_dim]
# GEMM: ( (seql_k*seqs) x 2*embed_dim ) x ( 2*embed_dim x embed_dim ) = (seql_k*seqs x embed_dim)
input_lin_kv_results_grads = input_lin_kv_results_grads.view(inputs_kv.size(0)*inputs_kv.size(1), heads_t[0]*2*head_dim)
input_lin_kv_results_grads = input_lin_kv_results_grads.view(
inputs_kv.size(0) * inputs_kv.size(1), heads_t[0] * 2 * head_dim
)
input_kv_grads = torch.mm(input_lin_kv_results_grads, input_weights_kv)
input_kv_grads = input_kv_grads.view(inputs_kv.size(0), inputs_kv.size(1), inputs_kv.size(2))
# Input Q Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [embed_dim, embed_dim]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (embed_dim x embed_dim)
input_weight_q_grads = torch.mm(queries_grads.transpose(0,1), inputs_q.view(inputs_q.size(0)*inputs_q.size(1), inputs_q.size(2)))
input_weight_q_grads = torch.mm(
queries_grads.transpose(0, 1), inputs_q.view(inputs_q.size(0) * inputs_q.size(1), inputs_q.size(2))
)
# Input KV Linear GEMM - WGRAD
# input1: (data grads) [seql_k*seqs, 2*embed_dim(2048)]
# input2: (activations) [seql_k*seqs, embed_dim(1024)]
# input2: (activations) [seql_k*seqs, embed_dim(1024)]
# output: [2*embed_dim, embed_dim]
# GEMM: ( 2*embed_dim x seql_k*seqs ) x ( seql_k*seqs x embed_dim ) = (2*embed_dim x embed_dim)
input_weight_kv_grads = torch.mm(input_lin_kv_results_grads.transpose(0,1), inputs_kv.view(inputs_kv.size(0)*inputs_kv.size(1), inputs_kv.size(2)))
input_weight_kv_grads = torch.mm(
input_lin_kv_results_grads.transpose(0, 1),
inputs_kv.view(inputs_kv.size(0) * inputs_kv.size(1), inputs_kv.size(2)),
)
if use_biases_t[0]:
input_bias_grads_q = torch.sum(queries_grads, 0)
......@@ -259,10 +335,22 @@ class EncdecAttnFunc(torch.autograd.Function):
input_bias_grads_q = None
input_bias_grads_kv = None
return None, None, None, None, \
input_q_grads, input_kv_grads, \
input_weight_q_grads, input_weight_kv_grads, output_weight_grads, \
input_bias_grads_q, input_bias_grads_kv, output_bias_grads, \
None, None
return (
None,
None,
None,
None,
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
input_bias_grads_q,
input_bias_grads_kv,
output_bias_grads,
None,
None,
)
encdec_attn_func = EncdecAttnFunc.apply
import torch
import fast_encdec_multihead_attn
import fast_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_encdec_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
(
input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.encdec_multihead_attn_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t)
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
input_q_grads, \
input_kv_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return (
None,
None,
None,
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
None,
None,
)
return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None
fast_encdec_attn_func = FastEncdecAttnFunc.apply
......@@ -6,125 +6,154 @@
# can be found in the PATENTS file in the same directory.
import torch
import fast_encdec_multihead_attn_norm_add
import fast_multihead_attn
class FastEncdecAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_q_results, \
input_lin_kv_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_encdec_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
(
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_q_results,
input_lin_kv_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs,
) = fast_multihead_attn.encdec_multihead_attn_norm_add_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_norm_add_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t[0],
)
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads = \
fast_encdec_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_q_results, \
input_lin_kv_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
# import pdb; pdb.set_trace()
return (
None,
None,
None,
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
None,
None,
)
#import pdb; pdb.set_trace()
return None, None, None, \
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads, \
None, None
fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply
import torch
import fast_self_multihead_attn
import fast_self_multihead_attn_bias
import fast_self_multihead_attn_bias_additive_mask
class FastSelfAttnFunc(torch.autograd.Function) :
import fast_multihead_attn
class FastSelfAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, input_weights, output_weights, input_biases, output_biases, pad_mask, mask_additive, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask,
mask_additive,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
mask_additive_t= torch.tensor([mask_additive])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
mask_additive_t = torch.tensor([mask_additive])
if use_biases_t[0]:
if not mask_additive:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
(
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_bias_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_bias.forward() \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
null_tensor,
null_tensor,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
else:
input_lin_results, \
bmm1_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn_bias_additive_mask.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
input_biases, \
output_biases, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
null_tensor, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
(
input_lin_results,
bmm1_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_bias_additive_mask_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_bias_additive_mask.forward( \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
null_tensor,
bmm1_results,
pad_mask,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
else:
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
outputs = \
fast_self_multihead_attn.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward(use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
null_tensor, \
null_tensor, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
(
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
outputs,
) = fast_multihead_attn.self_attn_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn.forward( \
ctx.save_for_backward(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
null_tensor,
null_tensor,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
bmm1_results, \
pad_mask, \
mask_additive_t, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
(
use_biases_t,
heads_t,
matmul2_results,
dropout_results,
softmax_results,
bmm1_results,
pad_mask,
mask_additive_t,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
if use_biases_t[0]:
if not mask_additive_t[0]:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
(
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
) = fast_multihead_attn.self_attn_bias_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_bias.backward( \
else:
input_grads, \
input_weight_grads, \
output_weight_grads, \
input_bias_grads, \
output_bias_grads = \
fast_self_multihead_attn_bias_additive_mask.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
bmm1_results, \
pad_mask, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
(
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
) = fast_multihead_attn.self_attn_bias_additive_mask_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
bmm1_results,
pad_mask,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_bias_additive_mask.backward( \
else:
input_bias_grads = None
input_bias_grads = None
output_bias_grads = None
input_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_grads, input_weight_grads, output_weight_grads,input_bias_grads, output_bias_grads, None, None, None
input_grads, input_weight_grads, output_weight_grads = fast_multihead_attn.self_attn_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn.backward( \
return (
None,
None,
None,
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
None,
None,
None,
)
fast_self_attn_func = FastSelfAttnFunc.apply
import torch
import fast_self_multihead_attn_norm_add
import fast_multihead_attn
class FastSelfAttnNormAddFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights, output_weights, pad_mask, dropout_prob):
heads_t = torch.tensor([heads])
def forward(
ctx,
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
input_lin_results, \
softmax_results, \
dropout_results, \
dropout_mask, \
matmul2_results, \
dropout_add_mask, \
outputs = \
fast_self_multihead_attn_norm_add.forward( \
use_mask, \
use_time_mask, \
is_training, \
heads, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
(
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
input_lin_results,
softmax_results,
dropout_results,
dropout_mask,
matmul2_results,
dropout_add_mask,
outputs,
) = fast_multihead_attn.self_attn_norm_add_forward(
use_mask,
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
pad_mask if use_mask else null_tensor,
dropout_prob,
)
# fast_self_multihead_attn_norm_add.forward( \
ctx.save_for_backward(heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t)
ctx.save_for_backward(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
heads_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t = ctx.saved_tensors
(
heads_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads,
) = fast_multihead_attn.self_attn_norm_add_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t[0],
)
# fast_self_multihead_attn_norm_add.backward( \
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads = \
fast_self_multihead_attn_norm_add.backward( \
heads_t[0], \
output_grads, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
return (
None,
None,
None,
input_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_grads,
output_weight_grads,
None,
None,
)
return None, None, None, \
input_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_grads, \
output_weight_grads, \
None, None
fast_self_attn_norm_add_func = FastSelfAttnNormAddFunc.apply
import torch
import fast_mask_softmax_dropout
import fast_additive_mask_softmax_dropout
import fast_multihead_attn
class MaskSoftmaxDropout(torch.autograd.Function) :
class MaskSoftmaxDropout(torch.autograd.Function):
@staticmethod
def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
heads_t = torch.tensor([heads])
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
use_mask = (pad_mask is not None)
use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive])
null_tensor = torch.tensor([])
use_mask = pad_mask is not None
use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive])
if mask_additive:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_additive_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
dropout_results, dropout_mask, softmax_results = fast_multihead_attn.additive_mask_softmax_dropout_forward(
use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
)
# fast_additive_mask_softmax_dropout.forward( \
else:
dropout_results, \
dropout_mask, \
softmax_results = \
fast_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
dropout_results, dropout_mask, softmax_results = fast_multihead_attn.mask_softmax_dropout_forward(
use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
)
# fast_mask_softmax_dropout.forward( \
ctx.save_for_backward(
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask if use_mask else null_tensor, \
mask_additive_t, \
dropout_prob_t)
use_mask_t,
heads_t,
softmax_results,
dropout_mask,
pad_mask if use_mask else null_tensor,
mask_additive_t,
dropout_prob_t,
)
return dropout_results.detach()
@staticmethod
def backward(ctx, output_grads):
use_mask_t, \
heads_t, \
softmax_results, \
dropout_mask, \
pad_mask, \
mask_additive_t, \
dropout_prob_t = ctx.saved_tensors
(
use_mask_t,
heads_t,
softmax_results,
dropout_mask,
pad_mask,
mask_additive_t,
dropout_prob_t,
) = ctx.saved_tensors
if mask_additive_t[0]:
input_grads = \
fast_additive_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
dropout_prob_t[0])
input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward(
use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, dropout_prob_t[0]
)
# fast_additive_mask_softmax_dropout.backward( \
else:
input_grads = \
fast_mask_softmax_dropout.backward( \
use_mask_t[0], \
heads_t[0], \
output_grads, \
softmax_results, \
dropout_mask, \
pad_mask, \
dropout_prob_t[0])
input_grads = fast_multihead_attn.mask_softmax_dropout_backward(
use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, pad_mask, dropout_prob_t[0]
)
# fast_mask_softmax_dropout.backward( \
return None, None, input_grads, None, None, None
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply
......@@ -5,16 +5,17 @@ from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .self_multihead_attn_func import self_attn_func
from .fast_self_multihead_attn_func import fast_self_attn_func
from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.normalization.fused_layer_norm import FusedLayerNorm
if hasattr(torch._C, '_jit_set_profiling_executor') :
if hasattr(torch._C, "_jit_set_profiling_executor"):
torch._C._jit_set_profiling_executor(False)
if hasattr(torch._C, '_jit_set_profiling_mode') :
if hasattr(torch._C, "_jit_set_profiling_mode"):
torch._C._jit_set_profiling_mode(False)
@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor
......@@ -28,7 +29,18 @@ class SelfMultiheadAttn(nn.Module):
See "Attention Is All You Need" for more details.
"""
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast', separate_qkv_params=False, mask_additive=False):
def __init__(
self,
embed_dim,
num_heads,
dropout=0.0,
bias=False,
include_norm_add=False,
impl="fast",
separate_qkv_params=False,
mask_additive=False,
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
......@@ -38,61 +50,69 @@ class SelfMultiheadAttn(nn.Module):
self.bias = bias
self.include_norm_add = include_norm_add
self.impl = impl
self.scaling = self.head_dim**-0.5
self.scaling = self.head_dim ** -0.5
self.separate_qkv_params = separate_qkv_params
self.mask_additive = mask_additive
if mask_additive:
assert self.include_norm_add == False, "additive mask not supported with layer norm"
assert impl == 'default' or (impl == 'fast' and bias), "additive mask not supported for fast mode without bias"
assert impl == "default" or (
impl == "fast" and bias
), "additive mask not supported for fast mode without bias"
if separate_qkv_params:
self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.q_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.k_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
self.v_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
else:
self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias:
if separate_qkv_params:
self.q_bias = Parameter(torch.Tensor(embed_dim))
self.k_bias = Parameter(torch.Tensor(embed_dim))
self.v_bias = Parameter(torch.Tensor(embed_dim))
self.q_bias = Parameter(torch.Tensor(embed_dim))
self.k_bias = Parameter(torch.Tensor(embed_dim))
self.v_bias = Parameter(torch.Tensor(embed_dim))
else:
self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else:
if separate_qkv_params:
self.register_parameter('q_bias', None)
self.register_parameter('k_bias', None)
self.register_parameter('v_bias', None)
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
self.q_bias = None
self.k_bias = None
self.v_bias = None
else:
self.register_parameter('in_proj_bias', None)
self.register_parameter("in_proj_bias", None)
self.in_proj_bias = None
self.register_parameter('out_proj_bias', None)
self.register_parameter("out_proj_bias", None)
self.out_proj_bias = None
if self.include_norm_add:
if impl == 'fast':
if impl == "fast":
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None
else:
self.register_parameter('lyr_norm_gamma_weights', None)
self.register_parameter('lyr_norm_beta_weights', None)
self.register_parameter("lyr_norm_gamma_weights", None)
self.register_parameter("lyr_norm_beta_weights", None)
self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm_beta_weights = None
self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters()
if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_self_attn_norm_add_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
if impl == "fast":
self.attn_func = fast_self_attn_norm_add_func
elif impl == "default":
self.attn_func = self_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
else:
if impl == 'fast' : self.attn_func = fast_self_attn_func
elif impl == 'default' : self.attn_func = self_attn_func
else : assert False, "Unsupported impl: {} !".format(impl)
if impl == "fast":
self.attn_func = fast_self_attn_func
elif impl == "default":
self.attn_func = self_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self):
if self.separate_qkv_params:
......@@ -108,14 +128,14 @@ class SelfMultiheadAttn(nn.Module):
nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias:
if self.separate_qkv_params:
nn.init.constant_(self.q_bias, 0.)
nn.init.constant_(self.k_bias, 0.)
nn.init.constant_(self.v_bias, 0.)
nn.init.constant_(self.q_bias, 0.0)
nn.init.constant_(self.k_bias, 0.0)
nn.init.constant_(self.v_bias, 0.0)
else:
nn.init.constant_(self.in_proj_bias, 0.)
nn.init.constant_(self.out_proj_bias, 0.)
nn.init.constant_(self.in_proj_bias, 0.0)
nn.init.constant_(self.out_proj_bias, 0.0)
if self.include_norm_add:
if self.impl == 'fast':
if self.impl == "fast":
nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights)
else:
......@@ -131,18 +151,40 @@ class SelfMultiheadAttn(nn.Module):
batch x src_len, where padding elements are indicated by 1s.
"""
if self.separate_qkv_params:
input_weights = torch.cat([self.q_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.k_weight.view(self.num_heads,1,self.head_dim,self.embed_dim), self.v_weight.view(self.num_heads,1,self.head_dim,self.embed_dim)], dim=1).reshape(3*self.embed_dim,self.embed_dim).contiguous()
else:
input_weights = (
torch.cat(
[
self.q_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
self.k_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
self.v_weight.view(self.num_heads, 1, self.head_dim, self.embed_dim),
],
dim=1,
)
.reshape(3 * self.embed_dim, self.embed_dim)
.contiguous()
)
else:
input_weights = self.in_proj_weight
if self.bias:
if self.separate_qkv_params:
input_bias = torch.cat([self.q_bias.view(self.num_heads,1,self.head_dim), self.k_bias.view(self.num_heads,1,self.head_dim), self.v_bias.view(self.num_heads,1,self.head_dim)],dim=1).reshape(3*self.embed_dim).contiguous()
input_bias = (
torch.cat(
[
self.q_bias.view(self.num_heads, 1, self.head_dim),
self.k_bias.view(self.num_heads, 1, self.head_dim),
self.v_bias.view(self.num_heads, 1, self.head_dim),
],
dim=1,
)
.reshape(3 * self.embed_dim)
.contiguous()
)
else:
input_bias = self.in_proj_bias
else:
input_bias=None
input_bias = None
if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!"
assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask
elif attn_mask is not None:
assert self.mask_additive == False, "additive mask not supported for time mask"
......@@ -151,28 +193,68 @@ class SelfMultiheadAttn(nn.Module):
mask = None
if self.include_norm_add:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights,
input_weights, self.out_proj_weight, mask, self.dropout)
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
self.lyr_nrm_gamma_weights,
self.lyr_nrm_beta_weights,
input_weights,
self.out_proj_weight,
mask,
self.dropout,
)
else:
lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results,
input_weights, self.out_proj_weight,
input_bias, self.out_proj_bias,
mask, self.mask_additive, self.dropout)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
lyr_nrm_results,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else:
outputs = outputs + query
else:
if self.impl == 'fast':
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query,
input_weights, self.out_proj_weight, input_bias, self.out_proj_bias, mask, self.mask_additive, self.dropout)
if self.impl == "fast":
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
query,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query,
input_weights, self.out_proj_weight,
input_bias, self.out_proj_bias,
mask, self.mask_additive, self.dropout)
outputs = self.attn_func(
attn_mask is not None,
is_training,
self.num_heads,
self.scaling,
query,
input_weights,
self.out_proj_weight,
input_bias,
self.out_proj_bias,
mask,
self.mask_additive,
self.dropout,
)
return outputs,None
return outputs, None
import torch
import torch.nn.functional as F
class SelfAttnFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, use_time_mask, is_training, heads, scale, inputs,
input_weights, output_weights,
input_biases, output_biases,
mask, is_additive_mask, dropout_prob):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
def forward(
ctx,
use_time_mask,
is_training,
heads,
scale,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
mask,
is_additive_mask,
dropout_prob,
):
use_biases_t = torch.tensor([input_biases is not None])
heads_t = torch.tensor([heads])
scale_t = torch.tensor([scale])
dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
null_tensor = torch.tensor([])
head_dim = inputs.size(2) // heads
# Input Linear GEMM
# input1: (activations) [seql_q, seqs, embed_dim(1024)]
......@@ -20,22 +32,27 @@ class SelfAttnFunc(torch.autograd.Function):
# output: [seql_q, seqs, embed_dim*3]
# GEMM: ( (seql_q*seqs) x embed_dim ) x ( embed_dim x embed_dim*3 ) = (seql_q*seqs x embed_dim*3)
if use_biases_t[0]:
input_lin_results = torch.addmm(input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0,1),
beta=1., alpha=1.)
input_lin_results = torch.addmm(
input_biases,
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
input_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
input_lin_results = torch.mm(inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0,1))
input_lin_results = torch.mm(
inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2)), input_weights.transpose(0, 1)
)
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1), input_weights.size(0))
# Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads, 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads, 3, head_dim)
queries = input_lin_results[:, :, 0, :]
keys = input_lin_results[:, :, 1, :]
values = input_lin_results[:, :, 2, :]
# Matmul1 Batched GEMMs
# The output tensor is specified prior to the Batch GEMM because baddbmm requires its specification
......@@ -45,36 +62,45 @@ class SelfAttnFunc(torch.autograd.Function):
# Input2: (Keys) [seql_k, seqs*heads, head_dim] transpose(0,1)
# output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul1_results = torch.empty((queries.size(1),queries.size(0),keys.size(0)), dtype=queries.dtype, device=torch.device('cuda'))
matmul1_results = torch.baddbmm(matmul1_results, queries.transpose(0,1), keys.transpose(0,1).transpose(1,2), out=matmul1_results, beta=0.0, alpha=scale_t[0])
matmul1_results = torch.empty(
(queries.size(1), queries.size(0), keys.size(0)), dtype=queries.dtype, device=torch.device("cuda")
)
matmul1_results = torch.baddbmm(
matmul1_results,
queries.transpose(0, 1),
keys.transpose(0, 1).transpose(1, 2),
out=matmul1_results,
beta=0.0,
alpha=scale_t[0],
)
if mask is not None:
# Self Attention Time Mask
if use_time_mask:
assert (len(mask.size()) == 2), "Timing mask is not 2D!"
assert (mask.size(0) == mask.size(1)), "Sequence length should match!"
assert len(mask.size()) == 2, "Timing mask is not 2D!"
assert mask.size(0) == mask.size(1), "Sequence length should match!"
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask, float('-inf'))
matmul1_results = matmul1_results.masked_fill_(mask, float("-inf"))
# Key Padding Mask
else:
batches,seql_q,seql_k = matmul1_results.size()
batches, seql_q, seql_k = matmul1_results.size()
seqs = int(batches / heads)
matmul1_results = matmul1_results.view(seqs, heads, seql_q, seql_k)
if is_additive_mask:
matmul1_results = matmul1_results + mask.unsqueeze(1).unsqueeze(2)
else:
mask = mask.to(torch.bool)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float('-inf'))
matmul1_results = matmul1_results.view(seqs*heads, seql_q, seql_k)
matmul1_results = matmul1_results.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf"))
matmul1_results = matmul1_results.view(seqs * heads, seql_q, seql_k)
softmax_results = F.softmax(matmul1_results, dim=-1)
# Dropout - is not executed for inference
if is_training:
dropout_results,dropout_mask = torch._fused_dropout(softmax_results, p=(1.-dropout_prob_t[0]))
dropout_results, dropout_mask = torch._fused_dropout(softmax_results, p=(1.0 - dropout_prob_t[0]))
else:
dropout_results = softmax_results
dropout_mask = null_tensor
dropout_mask = null_tensor
# Matmul2 Batched GEMMs
# The output tensor specification is needed here to specify the non-standard output.
......@@ -84,9 +110,15 @@ class SelfAttnFunc(torch.autograd.Function):
# Input2: (values) [seql_v, seqs*heads, head_dim] transpose(0,1)
# Output: [seql_q, seqs*heads, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = (seql_q x head_dim)
matmul2_results = torch.empty((dropout_results.size(1), dropout_results.size(0), values.size(2)), dtype=dropout_results.dtype, device=torch.device('cuda')).transpose(1,0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0,1), out=matmul2_results)
matmul2_results = matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
matmul2_results = torch.empty(
(dropout_results.size(1), dropout_results.size(0), values.size(2)),
dtype=dropout_results.dtype,
device=torch.device("cuda"),
).transpose(1, 0)
matmul2_results = torch.bmm(dropout_results, values.transpose(0, 1), out=matmul2_results)
matmul2_results = (
matmul2_results.transpose(0, 1).contiguous().view(inputs.size(0), inputs.size(1), inputs.size(2))
)
# Output Linear GEMM
# Input1: (activations) [seql_q, seqs, embed_dim=heads*head_dim]
......@@ -94,81 +126,96 @@ class SelfAttnFunc(torch.autograd.Function):
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
if use_biases_t[0]:
outputs = torch.addmm(output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0,1),
beta=1., alpha=1.)
outputs = torch.addmm(
output_biases,
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)),
output_weights.transpose(0, 1),
beta=1.0,
alpha=1.0,
)
else:
outputs = torch.mm(matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0,1))
outputs = torch.mm(
matmul2_results.view(inputs.size(0) * inputs.size(1), inputs.size(2)), output_weights.transpose(0, 1)
)
outputs = outputs.view(inputs.size(0), inputs.size(1), output_weights.size(0))
ctx.save_for_backward(use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t)
ctx.save_for_backward(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
)
return outputs.detach()
@staticmethod
def backward(ctx, output_grads):
use_biases_t, \
heads_t, \
scale_t, \
matmul2_results, \
dropout_results, \
softmax_results, \
input_lin_results, \
inputs, \
input_weights, \
output_weights, \
dropout_mask, \
dropout_prob_t = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
(
use_biases_t,
heads_t,
scale_t,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
head_dim = inputs.size(2) // heads_t[0]
# Slice out q,k,v from one big Input Linear outuput (should only impact meta data, no copies!)
# Sequences and heads are combined to make the batch of the Batched GEMM
# input_lin_results: [seql_q, seqs, heads(16), 3, head_dim(64)]
# input_lin_results: [seql_q, batches=seqs*heads, 3, head_dim]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1)*heads_t[0], 3, head_dim)
queries = input_lin_results[:,:,0,:]
keys = input_lin_results[:,:,1,:]
values = input_lin_results[:,:,2,:]
input_lin_results = input_lin_results.view(inputs.size(0), inputs.size(1) * heads_t[0], 3, head_dim)
queries = input_lin_results[:, :, 0, :]
keys = input_lin_results[:, :, 1, :]
values = input_lin_results[:, :, 2, :]
# Slice out q,k,v from one big set of gradients entering the input linear's bprop (should only impact meta data, no copies!)
# The gradients are identical in size to the Input Linear outputs.
# The tensor is declared before hand to properly slice out query, key, and value grads.
input_lin_results_grads = torch.empty_like(input_lin_results)
queries_grads = input_lin_results_grads[:,:,0,:]
keys_grads = input_lin_results_grads[:,:,1,:]
values_grads = input_lin_results_grads[:,:,2,:]
queries_grads = input_lin_results_grads[:, :, 0, :]
keys_grads = input_lin_results_grads[:, :, 1, :]
values_grads = input_lin_results_grads[:, :, 2, :]
# Output Linear GEMM - DGRAD
# Input1: (data grads) [seql_q, seqs, embed_dim=heads*head_dim]
# Input2: (weights) [ embed_dim, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( seql_q*seqs x embed_dim ) x ( embed_dim x embed_dim ) = ( seql_q*seqs x embed_dim )
output_lin_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights)
output_lin_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), output_weights
)
output_lin_grads = output_lin_grads.view(output_grads.size(0), output_grads.size(1), output_weights.size(1))
# Output Linear GEMM - WGRAD
# Input1: (data grads) [seql_q*seqs, embed_dim=heads*head_dim] transpose(0,1)
# Input2: (activations) [seql_q*seqs, embed_dim ]
# Output: [ seql_q, seqs, embed_dim ]
# GEMM: ( embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = ( embed_dim x embed_dim )
output_weight_grads = torch.mm(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0,1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)))
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1)*heads_t[0], head_dim).transpose(0,1)
output_weight_grads = torch.mm(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)).transpose(0, 1),
matmul2_results.view(matmul2_results.size(0) * matmul2_results.size(1), matmul2_results.size(2)),
)
output_lin_grads = output_lin_grads.view(inputs.size(0), inputs.size(1) * heads_t[0], head_dim).transpose(0, 1)
if use_biases_t[0]:
output_bias_grads = torch.sum(output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0)
output_bias_grads = torch.sum(
output_grads.view(output_grads.size(0) * output_grads.size(1), output_grads.size(2)), 0
)
else:
output_bias_grads = None
......@@ -177,59 +224,84 @@ class SelfAttnFunc(torch.autograd.Function):
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0,1).transpose(1,2))
matmul2_dgrad1 = torch.bmm(output_lin_grads, values.transpose(0, 1).transpose(1, 2))
# Matmul2 - DGRAD2
# Input1: (data grads) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1).transpose(1,2)
# Output: [seqs*heads, seql_q, seql_k]
# GEMM: Per batch: ( seql_q x head_dim ) x ( head_dim x seql_k ) = ( seql_q x seql_k )
values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
values_grads = torch.bmm(dropout_results.transpose(1, 2), output_lin_grads, out=values_grads.transpose(0, 1))
# Mask and Scaling for Dropout (not a publically documented op)
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0 / (1.0 - dropout_prob_t[0]))
# Softmax Grad (not a publically documented op)
softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
# Matmul1 - DGRAD1
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input1: (data grads) [seqs*heads, seql_q, seql_k]
# Input2: (activations) [seql_k, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_q, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_q x seql_k ) x ( seql_k x head_dim ) = ( seql_q x head_dim )
queries_grads = torch.baddbmm(queries_grads.transpose(0,1), softmax_grads, keys.transpose(0,1),
out=queries_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
queries_grads = torch.baddbmm(
queries_grads.transpose(0, 1),
softmax_grads,
keys.transpose(0, 1),
out=queries_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Matmul1 - DGRAD2
# Input1: (data grads) [seqs*heads, seql_q, seql_k] transpose(1,2)
# Input2: (activations) [seql_q, seqs*heads, head_dim] transpose(0,1)
# Output: [seqs*heads, seql_k, head_dim] transpose(0,1)
# GEMM: Per batch: ( seql_k x seql_q ) x ( seql_q x head_dim ) = ( seql_k x head_dim )
keys_grads = torch.baddbmm(keys_grads.transpose(0,1), softmax_grads.transpose(1,2), queries.transpose(0,1),
out=keys_grads.transpose(0,1), beta=0.0, alpha=scale_t[0])
keys_grads = torch.baddbmm(
keys_grads.transpose(0, 1),
softmax_grads.transpose(1, 2),
queries.transpose(0, 1),
out=keys_grads.transpose(0, 1),
beta=0.0,
alpha=scale_t[0],
)
# Input Linear GEMM - DGRAD
# input1: (data grads) [seql_q, seqs, 3*embed_dim(3072)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)]
# input2: (weights) [embed_dim*3 (3072), embed_dim (1024)]
# output: [seql_q, seqs, embed_dim]
# GEMM: ( (seql_q*seqs) x 3*embed_dim ) x ( 3*embed_dim x embed_dim ) = (seql_q*seqs x embed_dim)
input_lin_results_grads = input_lin_results_grads.view(inputs.size(0)*inputs.size(1), heads_t[0]*3*head_dim)
input_lin_results_grads = input_lin_results_grads.view(
inputs.size(0) * inputs.size(1), heads_t[0] * 3 * head_dim
)
input_grads = torch.mm(input_lin_results_grads, input_weights)
input_grads = input_grads.view(inputs.size(0), inputs.size(1), inputs.size(2))
# Input Linear GEMM - WGRAD
# input1: (data grads) [seql_q*seqs, 3*embed_dim(3072)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# input2: (activations) [seql_q*seqs, embed_dim(1024)]
# output: [3*embed_dim, embed_dim]
# GEMM: ( 3*embed_dim x seql_q*seqs ) x ( seql_q*seqs x embed_dim ) = (3*embed_dim x embed_dim)
input_weight_grads = torch.mm(input_lin_results_grads.transpose(0,1), inputs.view(inputs.size(0)*inputs.size(1), inputs.size(2)))
input_weight_grads = torch.mm(
input_lin_results_grads.transpose(0, 1), inputs.view(inputs.size(0) * inputs.size(1), inputs.size(2))
)
if use_biases_t[0]:
input_bias_grads = torch.sum(input_lin_results_grads, 0)
else:
input_bias_grads = None
return None, None, None, None, \
input_grads, \
input_weight_grads, output_weight_grads, \
input_bias_grads, output_bias_grads, \
None, None
return (
None,
None,
None,
None,
input_grads,
input_weight_grads,
output_weight_grads,
input_bias_grads,
output_bias_grads,
None,
None,
)
self_attn_func = SelfAttnFunc.apply
......@@ -385,112 +385,34 @@ if "--fast_multihead_attn" in sys.argv:
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
cc_flag.append('-gencode')
cc_flag.append('arch=compute_86,code=sm_86')
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
ext_modules.append(
CUDAExtension(name='fast_additive_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_mask_softmax_dropout',
sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_bias',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_self_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
ext_modules.append(
CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
'nvcc':['-O3',
'-gencode', 'arch=compute_70,code=sm_70',
'-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")]))
CUDAExtension(
name='fast_multihead_attn',
sources=[
'apex/contrib/csrc/multihead_attn/multihead_attn_frontend.cpp',
'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu',
"apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu",
"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu",
"apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu",
"apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu",
"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu",
"apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu",
"apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu",
],
extra_compile_args={
'cxx': ['-O3'] + version_dependent_macros + generator_flag,
'nvcc': [
'-O3', '-gencode', 'arch=compute_70,code=sm_70', '-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__', '--expt-relaxed-constexpr', '--expt-extended-lambda',
'--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag,
},
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/multihead_attn/cutlass")],
)
)
if "--transducer" in sys.argv:
sys.argv.remove("--transducer")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment