Unverified Commit 5754fa7a authored by Kevin Stephano, committed by GitHub

Fixes to Multihead Attention with LayerNorm and Dropout-Add (#860)

parent 6c2babf9
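For reference, the dropout behavior the fused kernels below implement (and that the `<= p` / `pinv` changes touch) can be sketched in a few lines of PyTorch. This is a minimal sketch only, assuming the call sites pass the keep probability as `p` (the `(1.0f - dropout_prob)` arguments in the diffs suggest this); `dropout_add_reference` and `masked_scale_reference` are illustrative names, not functions from this patch.

import torch

def dropout_add_reference(x, residual, dropout_prob, training=True):
    # keep_prob corresponds to the `p` the CUDA call sites pass: 1.0 - dropout_prob.
    keep_prob = 1.0 - dropout_prob
    if not training:
        return x + residual, None
    # Mask is 1 where the uniform draw falls within keep_prob, 0 otherwise.
    mask = (torch.rand_like(x, dtype=torch.float32) <= keep_prob).to(x.dtype)
    # Inverted dropout: scale kept activations by 1/keep_prob, then add the residual.
    out = residual + x * mask * (1.0 / keep_prob)
    return out, mask

def masked_scale_reference(grad_out, mask, dropout_prob):
    # Backward of the dropout branch reuses the saved mask with the same scale,
    # which is what apex_masked_scale_cuda is invoked with: 1.0 / (1.0 - dropout_prob).
    return grad_out * mask.to(grad_out.dtype) * (1.0 / (1.0 - dropout_prob))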
@@ -42,10 +42,11 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
        linearIndex += gridDim.x * blockDim.x*UNROLL) {
     float4 rand = curand_uniform4(&state);
     scalar_t src[UNROLL];
-    rand.x = rand.x < p;
-    rand.y = rand.y < p;
-    rand.z = rand.z < p;
-    rand.w = rand.w < p;
+    rand.x = rand.x <= p;
+    rand.y = rand.y <= p;
+    rand.z = rand.z <= p;
+    rand.w = rand.w <= p;
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
@@ -55,7 +56,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
-        outputs[li] = src[ii]*static_cast<scalar_t>((&rand.x)[ii]*pinv);
+        outputs[li] = src[ii]*(&rand.x)[ii]*pinv;
         mask[li] = (uint8_t)(&rand.x)[ii];
       }
     }
@@ -94,10 +95,10 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
     float4 rand = curand_uniform4(&state);
     scalar_t src[UNROLL];
     scalar_t add_src[UNROLL];
-    rand.x = rand.x < p;
-    rand.y = rand.y < p;
-    rand.z = rand.z < p;
-    rand.w = rand.w < p;
+    rand.x = rand.x <= p;
+    rand.y = rand.y <= p;
+    rand.z = rand.z <= p;
+    rand.w = rand.w <= p;
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
@@ -108,9 +109,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
-        accscalar_t int1 = static_cast<accscalar_t>((&rand.x)[ii]) * static_cast<accscalar_t>(src[ii]);
-        accscalar_t int2 = int1 * static_cast<accscalar_t>(pinv);
-        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int2);
+        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
+        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
         mask[li] = (uint8_t)(&rand.x)[ii];
       }
     }
@@ -182,7 +182,7 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
   for (int ii = 0; ii < UNROLL; ii++) {
     IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
     if (li < totalElements) {
-      outputs[li] = static_cast<scalar_t>(src[ii]*static_cast<scalar_t>(scale)) * msk[ii];
+      outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);
    }
  }
}
...
@@ -182,9 +182,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -397,9 +397,9 @@ std::vector<torch::Tensor> bwd_cuda(
                attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
       dropout_elems,
       (1.0 / (1.0 - dropout_prob)));
...
@@ -204,9 +204,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -257,18 +257,18 @@ std::vector<torch::Tensor> fwd_cuda(
   // End-of-block Dropout-Add
   if (is_training) {
-    apex_dropout_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs_q.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_dropout_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs_q.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
         total_tokens_q,
         (1.0f - dropout_prob));
   } else {
-    apex_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs_q.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs_q.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         total_tokens_q);
   }
@@ -347,6 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
   torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
   torch::Tensor output_weight_grads = torch::empty_like(output_weights);
   // Intermediate Tensor Allocations
+  at::Tensor dropout_add_grads = torch::empty_like(output_grads);
   at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
   at::Tensor matmul2_grads = torch::empty_like(dropout_results);
   at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
@@ -369,9 +370,9 @@ std::vector<torch::Tensor> bwd_cuda(
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(output_grads.data_ptr()),
-      static_cast<half*>(output_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(output_grads.data_ptr()),
+      static_cast<at::Half*>(dropout_add_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
       total_tokens_q,
       (1.0 / (1.0 - dropout_prob)));
@@ -387,7 +388,7 @@ std::vector<torch::Tensor> bwd_cuda(
                   static_cast<const void*>(output_weights.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
-                  static_cast<const void*>(output_grads.data_ptr()),
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
                   static_cast<const void*>(&beta),
@@ -408,7 +409,7 @@ std::vector<torch::Tensor> bwd_cuda(
                   static_cast<const void*>(matmul2_results.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
-                  static_cast<const void*>(output_grads.data_ptr()),
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
                   static_cast<const void*>(&beta),
@@ -459,9 +460,9 @@ std::vector<torch::Tensor> bwd_cuda(
                attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
       dropout_elems,
       (1.0 / (1.0 - dropout_prob)));
...
@@ -153,9 +153,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -200,7 +200,6 @@ std::vector<torch::Tensor> fwd_cuda(
                   CUDA_R_16F,
                   embed_dim,
                   CUDA_R_32F,
-                  //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -357,9 +356,9 @@ std::vector<torch::Tensor> bwd_cuda(
                attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
       dropout_elems,
       (1.0 / (1.0 - dropout_prob)));
@@ -434,7 +433,6 @@ std::vector<torch::Tensor> bwd_cuda(
                   CUDA_R_16F,
                   embed_dim,
                   CUDA_R_32F,
-                  //CUBLAS_GEMM_ALGO10_TENSOR_OP));
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // Input Linear Wgrad
...
@@ -176,9 +176,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -224,23 +224,22 @@ std::vector<torch::Tensor> fwd_cuda(
                   CUDA_R_16F,
                   embed_dim,
                   CUDA_R_32F,
-                  //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // End-of-block Dropout-Add
   if (is_training) {
-    apex_dropout_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_dropout_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
         total_tokens,
         (1.0f - dropout_prob));
   } else {
-    apex_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         total_tokens);
   }
@@ -309,6 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
   torch::Tensor input_weight_grads = torch::empty_like(input_weights);
   torch::Tensor output_weight_grads = torch::empty_like(output_weights);
   // Intermediate Tensor Allocations
+  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
   torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
   torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
   torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
@@ -330,9 +330,9 @@ std::vector<torch::Tensor> bwd_cuda(
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(output_grads.data_ptr()),
-      static_cast<half*>(output_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(output_grads.data_ptr()),
+      static_cast<at::Half*>(dropout_add_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
       total_tokens,
       (1.0 / (1.0 - dropout_prob)));
@@ -348,7 +348,7 @@ std::vector<torch::Tensor> bwd_cuda(
                   static_cast<const void*>(output_weights.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
-                  static_cast<const void*>(output_grads.data_ptr()),
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
                   static_cast<const void*>(&beta),
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
                   static_cast<const void*>(matmul2_results.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
-                  static_cast<const void*>(output_grads.data_ptr()),
+                  static_cast<const void*>(dropout_add_grads.data_ptr()),
                   CUDA_R_16F,
                   embed_dim,
                   static_cast<const void*>(&beta),
@@ -420,9 +420,9 @@ std::vector<torch::Tensor> bwd_cuda(
                attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
       dropout_elems,
       (1.0 / (1.0 - dropout_prob)));
...
import torch
import torch.nn.functional as F
import argparse
from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn
parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--seed-start', default=1, type=int, help='First random seed in the sweep.')
parser.add_argument('--seed-end', default=100, type=int, help='Last random seed in the sweep (inclusive).')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')
args = parser.parse_args()
assert args.seq_length % 64 == 0, "Sequence Length should be a multiple of 64!"
if not torch.cuda.is_available():
    raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)

dropout_prob = 0.1

for seed in range(args.seed_start, args.seed_end+1) :
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    ref_layer = None
    if args.encdec_attn :
        ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
    else :
        ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
    ref_layer.cuda()
    ref_layer.half()
    ref_layer.reset_parameters()

    ref_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    ref_inputs_kv = None
    if args.encdec_attn :
        ref_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    ref_grads = torch.randn_like(ref_inputs)
    ref_outputs,_ = ref_layer.forward(ref_inputs,
                                      ref_inputs_kv,
                                      ref_inputs_kv,
                                      key_padding_mask=None,
                                      need_weights=False,
                                      attn_mask=None,
                                      is_training=(not args.eval))
    ref_outputs.backward(ref_grads)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    tst_layer = None
    if args.encdec_attn :
        tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
    else:
        tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
    tst_layer.cuda()
    tst_layer.half()
    tst_layer.reset_parameters()

    tst_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    tst_inputs_kv = None
    if args.encdec_attn :
        tst_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

    assert torch.equal(ref_inputs,tst_inputs), "ERROR: Inputs are different!"

    tst_grads = torch.randn_like(tst_inputs)
    tst_outputs,_ = tst_layer.forward(tst_inputs,
                                      tst_inputs_kv,
                                      tst_inputs_kv,
                                      key_padding_mask=None,
                                      need_weights=False,
                                      attn_mask=None,
                                      is_training=(not args.eval))
    tst_outputs.backward(tst_grads)

    fwd_close = torch.equal(ref_outputs, tst_outputs)
    bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)

    diff_fwd = ref_outputs - tst_outputs
    diff_cnt_fwd = diff_fwd.ne(0.0).sum()
    diff_accum_fwd = diff_fwd.abs().sum()

    diff_bwd = ref_inputs.grad - tst_inputs.grad
    diff_cnt_bwd = diff_bwd.ne(0.0).sum()
    diff_accum_bwd = diff_bwd.abs().sum()

    print(">>> Seed: ", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())
@@ -6,7 +6,12 @@ import torch.nn.functional as F
 from .encdec_multihead_attn_func import encdec_attn_func
 from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
 from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
+from apex.normalization.fused_layer_norm import FusedLayerNorm
+
+if hasattr(torch._C, '_jit_set_profiling_executor') :
+    torch._C._jit_set_profiling_executor(False)
+if hasattr(torch._C, '_jit_set_profiling_mode') :
+    torch._C._jit_set_profiling_mode(False)
 
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
@@ -57,9 +62,9 @@ class EncdecMultiheadAttn(nn.Module):
             self.register_parameter('lyr_norm_beta_weights', None)
             self.lyr_nrm_gamma_weights = None
             self.lyr_nrm_beta_weights = None
-            self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
+            self.lyr_nrm = FusedLayerNorm(embed_dim)
         self.reset_parameters()
         if self.include_norm_add:
             if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func
             elif impl == 'default' : self.attn_func = encdec_attn_func
...
@@ -203,7 +203,7 @@ class EncdecAttnFunc(torch.autograd.Function):
         values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
         # Mask and Scaling for Dropout (not a publically documented op)
-        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
+        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
         # Softmax Grad (not a publically documented op)
         softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
...
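The scale-argument change above can be read as follows: assuming `torch._masked_scale(x, mask, s)` computes `x * mask * s` elementwise (it is not a publicly documented op, as the code comment notes, so this is an inference from how it is used here), the backward now rescales by 1/(1 - p) instead of p. A hedged one-line equivalent with an illustrative helper name:

import torch

def masked_scale_equivalent(grad, mask, dropout_prob):
    # Roughly what torch._masked_scale(grad, mask, 1.0/(1.0-dropout_prob)) is
    # expected to produce: zero out dropped positions and undo the
    # inverted-dropout scaling applied in the forward pass.
    return grad * mask.to(grad.dtype) * (1.0 / (1.0 - dropout_prob))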
@@ -6,7 +6,12 @@ import torch.nn.functional as F
 from .self_multihead_attn_func import self_attn_func
 from .fast_self_multihead_attn_func import fast_self_attn_func
 from .fast_self_multihead_attn_norm_add_func import fast_self_attn_norm_add_func
+from apex.normalization.fused_layer_norm import FusedLayerNorm
+
+if hasattr(torch._C, '_jit_set_profiling_executor') :
+    torch._C._jit_set_profiling_executor(False)
+if hasattr(torch._C, '_jit_set_profiling_mode') :
+    torch._C._jit_set_profiling_mode(False)
 
 @torch.jit.script
 def jit_dropout_add(x, residual, prob, is_training):
@@ -75,7 +80,7 @@ class SelfMultiheadAttn(nn.Module):
             self.register_parameter('lyr_norm_beta_weights', None)
             self.lyr_nrm_gamma_weights = None
             self.lyr_nrm_beta_weights = None
-            self.lyr_nrm = torch.nn.LayerNorm(embed_dim)
+            self.lyr_nrm = FusedLayerNorm(embed_dim)
         self.reset_parameters()
         if self.include_norm_add:
...