Commit 93f91cde authored by Kexin Yu's avatar Kexin Yu
Browse files

Merge remote-tracking branch 'upstream/master'

parents 33082d2b 80b90b9d
[submodule "apex/contrib/csrc/multihead_attn/cutlass"]
path = apex/contrib/csrc/multihead_attn/cutlass
url = https://github.com/NVIDIA/cutlass.git
branch = v1.2.0
...@@ -92,6 +92,14 @@ def lazy_init_with_master_weights(self): ...@@ -92,6 +92,14 @@ def lazy_init_with_master_weights(self):
def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None): def post_backward_models_are_masters(scaler, params, stashed_grads, scale_override=None):
grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0 grads_have_scale, stashed_have_scale, out_scale = scaler.loss_scale(), 1.0, 1.0
# not much to do if scale == 1.0 and static scaling
if scaler.loss_scale() == 1.0 and not scaler.dynamic:
# Clear the stash.
for i in range(len(stashed_grads)):
stashed_grads[i] = None
return
if scale_override is not None: if scale_override is not None:
grads_have_scale, stashed_have_scale, out_scale = scale_override grads_have_scale, stashed_have_scale, out_scale = scale_override
......
...@@ -63,7 +63,8 @@ FP32_FUNCS = [ ...@@ -63,7 +63,8 @@ FP32_FUNCS = [
'binary_cross_entropy_with_logits', 'binary_cross_entropy_with_logits',
'smooth_l1_loss', 'smooth_l1_loss',
'soft_margin_loss', 'soft_margin_loss',
'triplet_margin_loss' 'triplet_margin_loss',
'ctc_loss'
] ]
BANNED_FUNCS = [ BANNED_FUNCS = [
......
Subproject commit ed2ed4d667ce95e1371bd62db32b6a114e774336
#include <ATen/ATen.h>
#include <ATen/CUDAGenerator.h>
#include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h>
#include <THC/THCGeneral.h>
const int UNROLL = 4;
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p,
std::pair<uint64_t, uint64_t> seeds
)
{
accscalar_t pinv = accscalar_t(1)/p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
rand.x = rand.x < p;
rand.y = rand.y < p;
rand.z = rand.z < p;
rand.w = rand.w < p;
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii]*static_cast<scalar_t>((&rand.x)[ii]*pinv);
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_dropout_add_kernel(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p,
std::pair<uint64_t, uint64_t> seeds
)
{
accscalar_t pinv = accscalar_t(1)/p;
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(
seeds.first,
idx,
seeds.second,
&state);
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
float4 rand = curand_uniform4(&state);
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
rand.x = rand.x < p;
rand.y = rand.y < p;
rand.z = rand.z < p;
rand.w = rand.w < p;
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
add_src[ii] = add_inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
accscalar_t int1 = static_cast<accscalar_t>((&rand.x)[ii]) * static_cast<accscalar_t>(src[ii]);
accscalar_t int2 = int1 * static_cast<accscalar_t>(pinv);
outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int2);
mask[li] = (uint8_t)(&rand.x)[ii];
}
}
__syncthreads();
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_add_kernel( scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
IndexType totalElements
)
{
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL) {
scalar_t src[UNROLL];
scalar_t add_src[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = inputs[li];
add_src[ii] = add_inputs[li];
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = src[ii] + add_src[ii];
}
}
__syncthreads();
}
}
template<typename scalar_t,
typename accscalar_t,
typename IndexType
>
__global__ void apex_masked_scale_kernel(scalar_t const *inputs,
scalar_t *outputs,
uint8_t const *mask,
IndexType totalElements,
accscalar_t scale
)
{
IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
IndexType rounded_size = ((totalElements - 1)/(blockDim.x * gridDim.x * UNROLL)+1) * blockDim.x * gridDim.x * UNROLL;
for (IndexType linearIndex = idx;
linearIndex < rounded_size;
linearIndex += gridDim.x * blockDim.x*UNROLL)
{
scalar_t src[UNROLL];
scalar_t msk[UNROLL];
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
src[ii] = static_cast<scalar_t>(inputs[li]);
msk[ii] = static_cast<scalar_t>(mask[li]);
}
}
for (int ii = 0; ii < UNROLL; ii++) {
IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
if (li < totalElements) {
outputs[li] = static_cast<scalar_t>(src[ii]*static_cast<scalar_t>(scale)) * msk[ii];
}
}
}
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_fused_dropout_cuda(scalar_t const *inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p)
{
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}
apex_fused_dropout_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, p, rng_engine_inputs);
THCudaCheck(cudaGetLastError());
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_dropout_add_cuda(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
uint8_t *mask,
IndexType totalElements,
accscalar_t p)
{
auto gen = at::cuda::detail::getDefaultCUDAGenerator();
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs;
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
}
apex_dropout_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, mask, totalElements, p, rng_engine_inputs);
THCudaCheck(cudaGetLastError());
}
template <
typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_add_cuda(scalar_t const *inputs,
scalar_t const *add_inputs,
scalar_t *outputs,
IndexType totalElements
)
{
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
apex_add_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, add_inputs, outputs, totalElements);
THCudaCheck(cudaGetLastError());
}
template<typename scalar_t,
typename accscalar_t,
typename IndexType
>
void apex_masked_scale_cuda(scalar_t const *inputs,
scalar_t *outputs,
uint8_t const *mask,
IndexType totalElements,
accscalar_t scale
)
{
int block_size = 256;
dim3 dim_block(block_size);
dim3 grid((totalElements + block_size -1)/block_size);
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
apex_masked_scale_kernel<scalar_t, accscalar_t, IndexType><<<grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(inputs, outputs, mask, totalElements, scale);
THCudaCheck(cudaGetLastError());
}
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace encdec {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace encdec
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
}
This diff is collapsed.
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace encdec_norm_add {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_q_results,
torch::Tensor const& input_lin_kv_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs_q,
torch::Tensor const& inputs_kv,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights_q,
torch::Tensor const& input_weights_kv,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_q_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_kv_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs_kv.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights_q.dim() == 2, "expected 2D tensor");
AT_ASSERTM(input_weights_kv.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_q_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_kv_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(inputs_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_q.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights_kv.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
This diff is collapsed.
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.");
m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
}
This diff is collapsed.
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& lyr_nrm_results,
torch::Tensor const& lyr_nrm_mean,
torch::Tensor const& lyr_nrm_invvar,
torch::Tensor const& inputs,
torch::Tensor const& lyr_nrm_gamma_weights,
torch::Tensor const& lyr_nrm_beta_weights,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
torch::Tensor const& dropout_add_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_mean.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_invvar.dim() == 1, "expected 1D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(lyr_nrm_gamma_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(lyr_nrm_beta_weights.dim() == 1, "expected 1D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_add_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_mean.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(lyr_nrm_invvar.type().scalarType() == at::ScalarType::Float, "Only FLOAT is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_gamma_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(lyr_nrm_beta_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
AT_ASSERTM(dropout_add_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
This diff is collapsed.
This diff is collapsed.
import torch
import torch.nn.functional as F
import argparse
from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn
parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=10, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=120, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultitheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--biases', action='store_true', help='Execute multihead attention with Linear Biases.')
args = parser.parse_args()
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)
torch.manual_seed(111)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(111)
attn_layers = []
for idx in range(0, args.layers) :
if args.encdec_attn :
if args.ref :
attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=False, impl='default'))
else :
attn_layers.append(EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
else :
if args.native :
attn_layers.append(torch.nn.MultiheadAttention(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases))
elif args.ref :
attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='default'))
else :
attn_layers.append(SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=0.1, bias=args.biases, include_norm_add=args.norm_add, impl='fast'))
attn_layers[idx].cuda()
attn_layers[idx].half()
if not args.native :
attn_layers[idx].reset_parameters()
start_evt_fwd = []
start_evt_bwd = []
stop_evt_bwd = []
for recorded_trial in range(0, args.trials) :
start_evt_fwd.append(torch.cuda.Event(enable_timing=True))
start_evt_bwd.append(torch.cuda.Event(enable_timing=True))
stop_evt_bwd.append(torch.cuda.Event(enable_timing=True))
for sequences in range(args.num_seqs_start, args.num_seqs_stop + args.num_seqs_inc, args.num_seqs_inc) :
inputs = torch.randn(args.seq_length, sequences, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
grads = torch.randn_like(inputs)
for trial in range(0, args.trials + args.warmup_trials) :
layer_inputs = inputs
evt_idx = trial - args.warmup_trials
if evt_idx >= 0 :
start_evt_fwd[evt_idx].record()
for lyr_idx in range(0, args.layers) :
if args.native :
outputs,_ = attn_layers[lyr_idx].forward(layer_inputs,
layer_inputs,
layer_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None)
else :
outputs,_ = attn_layers[lyr_idx].forward(layer_inputs,
layer_inputs,
layer_inputs,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
layer_inputs = outputs
if evt_idx >= 0 :
start_evt_bwd[evt_idx].record()
if not args.fwd :
layer_inputs.backward(grads)
if evt_idx >= 0 :
stop_evt_bwd[evt_idx].record()
torch.cuda.synchronize()
elapsed_time_fwd = 0.0
elapsed_time_bwd = 0.0
for evt_idx in range(0, args.trials) :
elapsed_time_fwd += start_evt_fwd[evt_idx].elapsed_time(start_evt_bwd[evt_idx])
elapsed_time_bwd += start_evt_bwd[evt_idx].elapsed_time(stop_evt_bwd[evt_idx])
print("[ {} Attn {} ]Total Tokens: {:4d} Sequences: {:3d} Sequence Length: {:3d} Fwd Time / Layer: {:.3f} ms Bwd Time / Layer: {:.3f} ms".format(
'Encdec' if args.encdec_attn else 'Self', \
'Norm&Add' if args.norm_add else '', \
sequences*args.seq_length, \
sequences, \
args.seq_length, \
elapsed_time_fwd / ( args.trials * args.layers ), \
elapsed_time_bwd / ( args.trials * args.layers )))
# Fast Multihead Attention
This implementation has two main features :
* A C++ implementation to avoid the CPU overheads of Pytorch found with smaller batch sizes.
* The removal of all copies and transposes found in standard implementations of Multihead Attention.
| | Python Version | C++ Version |
| :----------------------------------------- | :------------: | :---------: |
| Layer Norm and Residual Add Variant | X | X |
| Includes Linear Biases | X | |
| Reduces CPU Overheads | | X |
| Fuses masking with Softmax | | X |
| Removes Transposes and Copies | X | X |
| Includes Self and Encoder/Decoder Variants | X | X |
## How to Instantiate
`SelfMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`
`EncdecMultiheadAttn(` _hidden dim_, _heads_, _dropout=prob_, _bias=bool_, _include_norm_add=bool_, _impl='fast'_ `)`
`impl` has two options:
* `fast` uses C++ Version
* `default` uses Python Version
## Instructions to build on Linux
```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" ./
```
## Try Performance Tests Yourself!
Perf test script is found here!
```
cd contrib/examples/multihead_attn
```
#### Fast Multihead Attention
```
python perf_test_multihead_attn.py --ref
```
#### Fast Multihead Attention with C++ Implementation
```
python perf_test_multihead_attn.py
```
#### Compare with `torch.nn.MultiheadAttn`
```
python perf_test_multihead_attn.py --native
```
#### Test your own range!
```
python perf_test_multihead_attn.py --seq-length 64 --num-seqs-start 10 --num-seqs-stop 120 --num-seqs-inc 5
```
## Performance Comparisons
* Performance was measured with 64 token sequence lengths on an NVIDIA TitanV card.
* Time is measured across multiple layers to simulate an in model scenario.
![Multihead Attention Forward](MHA_fwd.png)
![Multihead Attention Backward](MHA_bwd.png)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment