Unverified commit 37989915, authored by Ashish Farmer, committed by GitHub

Merge pull request #20 from lcskrishna/ifu_06052020

IFU_06_05_2020
parents b0c7d09f 097238f8
@@ -157,10 +157,8 @@ A Python-only build omits:
 - Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
   `DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
-To enable PyProf support, you need to install the packages required by PyProf. To do so, add the "--pyprof" option at installation time:
-```
-$ pip install -v --no-cache-dir --global-option="--pyprof" --global-option="--cpp_ext" --global-option="--cuda_ext" ./
-```
+Pyprof support has been moved to its own [dedicated repository](https://github.com/NVIDIA/PyProf).
+The codebase is deprecated in Apex and will be removed soon.
 ### Windows support
 Windows support is experimental, and Linux is recommended. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
......
#include <torch/extension.h>
#include <cuda_fp16.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
dropout_prob
);
}
} // end namespace additive_mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::additive_mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
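The pybind11 module above only declares the `forward`/`backward` bindings; how the extension is built and invoked is not part of this diff. The snippet below is a minimal, hypothetical sketch that drives such an extension through `torch.utils.cpp_extension.load`. The module name `additive_mask_softmax_dropout` and the source file names are assumptions for illustration, not something this commit defines.

```python
# Hypothetical JIT build of the extension declared above; names are assumptions.
import torch
from torch.utils.cpp_extension import load

ext = load(
    name="additive_mask_softmax_dropout",                   # assumed module name
    sources=["additive_masked_softmax_dropout.cpp",          # the pybind file above (name assumed)
             "additive_masked_softmax_dropout_cuda.cu"],     # the CUDA file that follows (name assumed)
    extra_cuda_cflags=["-O3"],
    verbose=True,
)

heads, seq_len, batch = 16, 64, 5
attn_scores = torch.randn(batch * heads, seq_len, seq_len,
                          dtype=torch.float16, device="cuda")
pad_mask = torch.zeros(batch, seq_len, dtype=torch.float16, device="cuda")  # additive mask values

# forward(use_mask, is_training, heads, input, pad_mask, dropout_prob)
dropout_results, dropout_mask, softmax_results = ext.forward(
    True, True, heads, attn_scores, pad_mask, 0.1)
```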
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace additive_mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const half* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_additive_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
//backward pass is completely in-place
return output_grads;
}
}
}
}
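For orientation, the forward above is, up to fusion and RNG details, an additive-mask softmax over the attention scores followed by dropout. The PyTorch sketch below is a rough functional reference, written under the assumption that the 2D `pad_mask` holds one additive row of length `k_seq_len` per sequence that is broadcast over that sequence's heads and query positions; it is an illustration, not a bit-exact reproduction of the kernel.

```python
import torch

def additive_mask_softmax_dropout_ref(scores, pad_mask, heads, dropout_prob, training):
    """Rough reference: scores is [batch*heads, q_len, k_len] (fp16),
    pad_mask is [batch, k_len] of additive values (e.g. 0 or a large negative number)."""
    b_times_h, q_len, k_len = scores.shape
    batch = b_times_h // heads
    if pad_mask is not None:
        # Broadcast the per-sequence additive mask over heads and query positions.
        mask = pad_mask.view(batch, 1, 1, k_len).expand(batch, heads, q_len, k_len)
        scores = (scores.view(batch, heads, q_len, k_len) + mask).view(b_times_h, q_len, k_len)
    probs = torch.softmax(scores.float(), dim=-1).to(scores.dtype)
    if training:
        # Inverted dropout: keep with probability (1 - p) and rescale by 1/(1 - p).
        keep = (torch.rand_like(probs, dtype=torch.float32) <= (1.0 - dropout_prob)).to(probs.dtype)
        return probs * keep / (1.0 - dropout_prob), keep.to(torch.uint8), probs
    return probs, None, probs
```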
@@ -42,10 +42,11 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
        linearIndex += gridDim.x * blockDim.x*UNROLL) {
     float4 rand = curand_uniform4(&state);
     scalar_t src[UNROLL];
-    rand.x = rand.x < p;
-    rand.y = rand.y < p;
-    rand.z = rand.z < p;
-    rand.w = rand.w < p;
+    rand.x = rand.x <= p;
+    rand.y = rand.y <= p;
+    rand.z = rand.z <= p;
+    rand.w = rand.w <= p;
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
@@ -55,7 +56,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
-        outputs[li] = src[ii]*static_cast<scalar_t>((&rand.x)[ii]*pinv);
+        outputs[li] = src[ii]*(&rand.x)[ii]*pinv;
         mask[li] = (uint8_t)(&rand.x)[ii];
       }
     }
@@ -94,10 +95,10 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
     float4 rand = curand_uniform4(&state);
     scalar_t src[UNROLL];
     scalar_t add_src[UNROLL];
-    rand.x = rand.x < p;
-    rand.y = rand.y < p;
-    rand.z = rand.z < p;
-    rand.w = rand.w < p;
+    rand.x = rand.x <= p;
+    rand.y = rand.y <= p;
+    rand.z = rand.z <= p;
+    rand.w = rand.w <= p;
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
@@ -108,9 +109,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
-        accscalar_t int1 = static_cast<accscalar_t>((&rand.x)[ii]) * static_cast<accscalar_t>(src[ii]);
-        accscalar_t int2 = int1 * static_cast<accscalar_t>(pinv);
-        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int2);
+        accscalar_t int1 = src[ii] * (&rand.x)[ii] * pinv;
+        outputs[li] = static_cast<scalar_t>(static_cast<accscalar_t>(add_src[ii]) + int1);
         mask[li] = (uint8_t)(&rand.x)[ii];
       }
     }
@@ -182,7 +182,7 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
     for (int ii = 0; ii < UNROLL; ii++) {
       IndexType li = linearIndex + blockDim.x * gridDim.x * ii;
       if (li < totalElements) {
-        outputs[li] = static_cast<scalar_t>(src[ii]*static_cast<scalar_t>(scale)) * msk[ii];
+        outputs[li] = static_cast<accscalar_t>(src[ii]) * scale * static_cast<accscalar_t>(msk[ii]);
       }
     }
   }
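In these kernels `p` is the keep probability (the call sites pass `1.0f - dropout_prob`) and `pinv` is its reciprocal. Since `curand_uniform4` draws from (0, 1], switching the comparison from `<` to `<=` guarantees that a keep probability of 1.0 never drops an element, and the `apex_dropout_add_kernel` change folds the two intermediate products into a single accumulation-precision expression. The Python sketch below illustrates the same inverted-dropout arithmetic; it is an assumed illustration, not a transcription of the kernel.

```python
# Inverted dropout as used by apex_fused_dropout_kernel, sketched in Python.
# Here p is the KEEP probability (the callers pass 1 - dropout_prob).
import random

def fused_dropout_ref(inputs, p):
    pinv = 1.0 / p
    outputs, mask = [], []
    for x in inputs:
        r = random.uniform(0.0, 1.0)
        keep = 1.0 if r <= p else 0.0     # '<=' so p == 1.0 can never drop a value
        outputs.append(x * keep * pinv)   # scale survivors so E[output] == input
        mask.append(int(keep))            # saved for the backward masked-scale pass
    return outputs, mask
```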
......
@@ -182,9 +182,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -397,9 +397,9 @@ std::vector<torch::Tensor> bwd_cuda(
                              attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_mask.data_ptr()),
       dropout_elems,
       (1.0 / (1.0 - dropout_prob)));
......
@@ -204,9 +204,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -257,18 +257,18 @@ std::vector<torch::Tensor> fwd_cuda(
   // End-of-block Dropout-Add
   if (is_training) {
-    apex_dropout_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs_q.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_dropout_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs_q.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
         total_tokens_q,
         (1.0f - dropout_prob));
   } else {
-    apex_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs_q.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs_q.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         total_tokens_q);
   }
@@ -347,6 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
   torch::Tensor input_weight_kv_grads = torch::empty_like(input_weights_kv);
   torch::Tensor output_weight_grads = torch::empty_like(output_weights);
   // Intermediate Tensor Allocations
+  at::Tensor dropout_add_grads = torch::empty_like(output_grads);
   at::Tensor output_lin_grads = torch::empty_like(matmul2_results);
   at::Tensor matmul2_grads = torch::empty_like(dropout_results);
   at::Tensor input_lin_q_output_grads = torch::empty_like(input_lin_q_results);
@@ -369,9 +370,9 @@ std::vector<torch::Tensor> bwd_cuda(
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(output_grads.data_ptr()),
-      static_cast<half*>(output_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(output_grads.data_ptr()),
+      static_cast<at::Half*>(dropout_add_grads.data_ptr()),
       static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
      total_tokens_q,
      (1.0 / (1.0 - dropout_prob)));
@@ -387,7 +388,7 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<const void*>(output_weights.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
-                   static_cast<const void*>(output_grads.data_ptr()),
+                   static_cast<const void*>(dropout_add_grads.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
                    static_cast<const void*>(&beta),
@@ -408,7 +409,7 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<const void*>(matmul2_results.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
-                   static_cast<const void*>(output_grads.data_ptr()),
+                   static_cast<const void*>(dropout_add_grads.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
                    static_cast<const void*>(&beta),
@@ -459,9 +460,9 @@ std::vector<torch::Tensor> bwd_cuda(
                              attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      dropout_elems,
      (1.0 / (1.0 - dropout_prob)));
......
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
);
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool is_training,
int heads,
torch::Tensor const& input,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
is_training,
heads,
input,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
torch::Tensor bwd(
bool use_mask,
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
torch::Tensor const& padding_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
// AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
softmax_results,
dropout_mask,
use_mask ? static_cast<const uint8_t*>(padding_mask.data_ptr()) : nullptr,
dropout_prob
);
}
} // end namespace mask_softmax_dropout
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::fused_softmax::mask_softmax_dropout::fwd, "Self Multihead Attention masked softmax dropout -- Forward.");
m.def("backward", &multihead_attn::fused_softmax::mask_softmax_dropout::bwd, "Self Multihead Attention masked softmax dropout -- Backward.");
}
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "softmax.h"
#include "dropout.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;
namespace multihead_attn {
namespace fused_softmax {
namespace mask_softmax_dropout {
std::vector<torch::Tensor> fwd_cuda(
bool is_training,
int heads,
torch::Tensor const& input,
const uint8_t* pad_mask,
float dropout_prob
)
{
const int attn_batches = input.size(0);
const int sequences = attn_batches / heads;
const int q_seq_len = input.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// 3 Intermediate Results + Output (Note: dropout intermediates are generated by ATen library code)
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
softmax_success = dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else {
softmax_success = dispatch_masked_softmax<half, half, float>(
reinterpret_cast<half*>(softmax_results_ptr),
reinterpret_cast<const half*>(input_ptr),
pad_mask,
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
attn_batches*q_seq_len/sequences);
}
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(softmax_results, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
// Matmul2
return {
dropout_results,
dropout_mask,
softmax_results
};
}
torch::Tensor bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
torch::Tensor const& dropout_mask,
const uint8_t *padding_mask,
float dropout_prob
)
{
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = q_seq_len;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
// Output Tensor Allocations
// torch::Tensor input_grads = torch::empty_like(output_grads);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
if (padding_mask == nullptr) {
dispatch_masked_scale_softmax_backward<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len);
} else{
dispatch_masked_scale_softmax_backward_masked_out<half, half, float,false>(
static_cast<half*>(output_grads.data_ptr()),
static_cast<half*>(output_grads.data_ptr()),
reinterpret_cast<half const*>(softmax_results.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
static_cast<uint8_t const*>(padding_mask),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len,
heads);
}
//backward pass is completely in-place
return output_grads;
}
}
}
}
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace self_bias {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const uint8_t* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const uint8_t*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
#include <torch/extension.h>
#include <vector>
#include <cuda_fp16.h>
namespace multihead_attn {
namespace self_bias_additive_mask {
namespace cublas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases,
torch::Tensor const& output_biases,
const half* pad_mask,
float dropout_prob
);
std::vector<torch::Tensor> bwd_cuda(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
//torch::Tensor const& input_biases,
//torch::Tensor const& output_biases,
torch::Tensor const& dropout_mask,
float dropout_prob
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
std::vector<torch::Tensor> fwd(
bool use_mask,
bool use_time_mask,
bool is_training,
int heads,
torch::Tensor const& inputs, torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& input_biases, torch::Tensor const& output_biases,
torch::Tensor const& pad_mask,
float dropout_prob
)
{
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
if (use_mask) {
AT_ASSERTM(pad_mask.dim() == 2, "expected 2D tensor");
  AT_ASSERTM(pad_mask.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
}
return fwd_cuda(
use_time_mask,
is_training,
heads,
inputs,
input_weights,
output_weights,
input_biases,
output_biases,
use_mask ? static_cast<const half*>(pad_mask.data_ptr()) : nullptr,
dropout_prob
);
}
std::vector<torch::Tensor> bwd(
int heads,
torch::Tensor const& output_grads,
torch::Tensor const& matmul2_results,
torch::Tensor const& dropout_results,
torch::Tensor const& softmax_results,
torch::Tensor const& input_lin_results,
torch::Tensor const& inputs,
torch::Tensor const& input_weights,
torch::Tensor const& output_weights,
torch::Tensor const& dropout_mask,
float dropout_prob
)
{
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(matmul2_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(dropout_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_lin_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM(inputs.dim() == 3, "expected 3D tensor");
AT_ASSERTM(input_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(output_weights.dim() == 2, "expected 2D tensor");
AT_ASSERTM(dropout_mask.dim() == 3, "expected 3D tensor");
AT_ASSERTM(output_grads.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(matmul2_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(softmax_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_lin_results.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(inputs.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(input_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(output_weights.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
AT_ASSERTM(dropout_mask.type().scalarType() == at::ScalarType::Byte, "Only BYTE is supported");
return bwd_cuda(
heads,
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_results,
inputs,
input_weights,
output_weights,
dropout_mask,
dropout_prob
);
}
} // end namespace cublas_gemmex
} // end namespace self_bias_additive_mask
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
}
...@@ -153,9 +153,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -153,9 +153,9 @@ std::vector<torch::Tensor> fwd_cuda(
assert(softmax_success); assert(softmax_success);
if (is_training) { if (is_training) {
apex_fused_dropout_cuda<half,float,uint32_t>( apex_fused_dropout_cuda<at::Half,float,uint32_t>(
static_cast<half const*>(softmax_results.data_ptr()), static_cast<at::Half const*>(softmax_results.data_ptr()),
static_cast<half*>(dropout_results.data_ptr()), static_cast<at::Half*>(dropout_results.data_ptr()),
static_cast<uint8_t*>(dropout_mask.data_ptr()), static_cast<uint8_t*>(dropout_mask.data_ptr()),
dropout_elems, dropout_elems,
(1.0f - dropout_prob)); (1.0f - dropout_prob));
...@@ -200,7 +200,6 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -200,7 +200,6 @@ std::vector<torch::Tensor> fwd_cuda(
CUDA_R_16F, CUDA_R_16F,
embed_dim, embed_dim,
CUDA_R_32F, CUDA_R_32F,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...@@ -357,9 +356,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -357,9 +356,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<half const*>(matmul2_grads.data_ptr()), static_cast<at::Half const*>(matmul2_grads.data_ptr()),
static_cast<half*>(matmul2_grads.data_ptr()), static_cast<at::Half*>(matmul2_grads.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()), static_cast<uint8_t const*>(dropout_mask.data_ptr()),
dropout_elems, dropout_elems,
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
...@@ -434,7 +433,6 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -434,7 +433,6 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_16F, CUDA_R_16F,
embed_dim, embed_dim,
CUDA_R_32F, CUDA_R_32F,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad // Input Linear Wgrad
......
@@ -176,9 +176,9 @@ std::vector<torch::Tensor> fwd_cuda(
   assert(softmax_success);
   if (is_training) {
-    apex_fused_dropout_cuda<half,float,uint32_t>(
-        static_cast<half const*>(softmax_results.data_ptr()),
-        static_cast<half*>(dropout_results.data_ptr()),
+    apex_fused_dropout_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(softmax_results.data_ptr()),
+        static_cast<at::Half*>(dropout_results.data_ptr()),
         static_cast<uint8_t*>(dropout_mask.data_ptr()),
         dropout_elems,
         (1.0f - dropout_prob));
@@ -224,23 +224,22 @@ std::vector<torch::Tensor> fwd_cuda(
                    CUDA_R_16F,
                    embed_dim,
                    CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO1_TENSOR_OP));
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // End-of-block Dropout-Add
   if (is_training) {
-    apex_dropout_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_dropout_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         static_cast<uint8_t*>(dropout_add_mask.data_ptr()),
         total_tokens,
         (1.0f - dropout_prob));
   } else {
-    apex_add_cuda<half,float,uint32_t>(
-        static_cast<half const*>(output_lin_results.data_ptr()),
-        static_cast<half const*>(inputs.data_ptr()),
-        static_cast<half*>(outputs.data_ptr()),
+    apex_add_cuda<at::Half,float,uint32_t>(
+        static_cast<at::Half const*>(output_lin_results.data_ptr()),
+        static_cast<at::Half const*>(inputs.data_ptr()),
+        static_cast<at::Half*>(outputs.data_ptr()),
         total_tokens);
   }
@@ -309,6 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
   torch::Tensor input_weight_grads = torch::empty_like(input_weights);
   torch::Tensor output_weight_grads = torch::empty_like(output_weights);
   // Intermediate Tensor Allocations
+  torch::Tensor dropout_add_grads = torch::empty_like(output_grads);
   torch::Tensor output_lin_grads = torch::empty_like(matmul2_results);
   torch::Tensor matmul2_grads = torch::empty_like(dropout_results);
   torch::Tensor input_lin_output_grads = torch::empty_like(input_lin_results);
@@ -330,9 +330,9 @@ std::vector<torch::Tensor> bwd_cuda(
   THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(output_grads.data_ptr()),
-      static_cast<half*>(output_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(output_grads.data_ptr()),
+      static_cast<at::Half*>(dropout_add_grads.data_ptr()),
      static_cast<uint8_t const*>(dropout_add_mask.data_ptr()),
      total_tokens,
      (1.0 / (1.0 - dropout_prob)));
@@ -348,7 +348,7 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<const void*>(output_weights.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
-                   static_cast<const void*>(output_grads.data_ptr()),
+                   static_cast<const void*>(dropout_add_grads.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
                    static_cast<const void*>(&beta),
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<const void*>(matmul2_results.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
-                   static_cast<const void*>(output_grads.data_ptr()),
+                   static_cast<const void*>(dropout_add_grads.data_ptr()),
                    CUDA_R_16F,
                    embed_dim,
                    static_cast<const void*>(&beta),
@@ -420,9 +420,9 @@ std::vector<torch::Tensor> bwd_cuda(
                              attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
-  apex_masked_scale_cuda<half,float,uint32_t>(
-      static_cast<half const*>(matmul2_grads.data_ptr()),
-      static_cast<half*>(matmul2_grads.data_ptr()),
+  apex_masked_scale_cuda<at::Half,float,uint32_t>(
+      static_cast<at::Half const*>(matmul2_grads.data_ptr()),
+      static_cast<at::Half*>(matmul2_grads.data_ptr()),
      static_cast<uint8_t const*>(dropout_mask.data_ptr()),
      dropout_elems,
      (1.0 / (1.0 - dropout_prob)));
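The backward changes in this hunk (and in the encdec variant earlier in the diff) route the dropout-add gradient into a freshly allocated `dropout_add_grads` tensor instead of scaling `output_grads` in place, so the incoming gradient stays untouched while the scaled copy feeds the output-projection GEMMs. Below is a hedged Python sketch of the dropout-add forward/backward pair this corresponds to; the function names are illustrative, not Apex APIs.

```python
import torch

def dropout_add_fwd(x, residual, dropout_prob, training):
    """outputs = dropout(x) + residual, also returning the keep mask."""
    if not training:
        return x + residual, None
    keep = (torch.rand_like(x, dtype=torch.float32) <= (1.0 - dropout_prob)).to(x.dtype)
    return x * keep / (1.0 - dropout_prob) + residual, keep.to(torch.uint8)

def dropout_add_bwd(output_grads, dropout_add_mask, dropout_prob):
    """Gradient through the dropout branch; the residual branch simply reuses output_grads."""
    # Corresponds to apex_masked_scale_cuda(output_grads -> dropout_add_grads, mask, 1/(1-p)).
    return output_grads * dropout_add_mask.to(output_grads.dtype) / (1.0 - dropout_prob)
```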
......
@@ -33,8 +33,10 @@ void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m,
                               float beta, half *c, long ldc, long strideC, long batchCount, cublasGemmAlgo_t algo=CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
   cublasOperation_t opa = convertTransToCublasOperation(transa);
   cublasOperation_t opb = convertTransToCublasOperation(transb);
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  cublasSetStream(handle, stream);
   float fAlpha = alpha;
   float fBeta = beta;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
@@ -131,7 +133,7 @@ void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
     AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");
     // Launch the CUTLASS GEMM kernel.
-    THCudaCheck(Gemm::launch(params));
+    THCudaCheck(Gemm::launch(params, stream));
     // Update batched GEMM params based on completed work
     batchesLeft = batchesLeft - iterBatchCount;
......
#include <torch/extension.h>
void multi_tensor_lamb_compute_update_term_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
const int step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float global_grad_norm,
const float max_global_grad_norm);
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate,
at::Tensor per_tensor_decay,
bool use_nvlamb);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda,
"Computes update term for LAMB optimizer");
m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda,
"Applies update term for LAMB optimizer");
}
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "type_shim.h"
#include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512
#define ILP 4
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}
template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
template <typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
vo = static_cast<TO_T>(vi);
}
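// Reading of the specializations below (an interpretation, not documented in this diff):
// convert(float / at::Half -> uint8_t) keeps only the high byte of the value's half-precision
// encoding (sign, 5 exponent bits, top 2 mantissa bits); adding s.as_float / 8.0f beforehand
// (half of the resulting quantization step) appears to round to nearest rather than truncate.
// convert(uint8_t -> float / at::Half) rebuilds the half by zero-filling the low byte.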
template <>
__device__ void convert(const float vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = vi;
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, float& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = static_cast<float>(t.as_half);
}
template <>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
union S
{
float as_float;
int as_int;
};
S s;
s.as_float = static_cast<float>(vi);
s.as_int = s.as_int & 0xFF800000;
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
vo = t.as_byte[1];
}
template <>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
union T
{
at::Half as_half;
uint8_t as_byte[2];
};
T t;
t.as_byte[0] = 0;
t.as_byte[1] = vi;
vo = t.as_half;
}
typedef enum{
MOMENT_MODE_0 =0, // L2 regularization mode
MOMENT_MODE_1 =1 // Decoupled weight decay mode
} adamMode_t;
template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage1Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<5>& tl,
const MATH_T* per_tensor_beta1,
const MATH_T* per_tensor_beta2,
const MATH_T* per_tensor_beta3,
const int* per_tensor_bias_correction,
const int step,
const MATH_T* per_tensor_epsilon,
adamMode_t mode,
const MATH_T* per_tensor_decay,
const MATH_T global_grad_norm,
const MATH_T max_global_grad_norm)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
MATH_T beta1 = per_tensor_beta1[tensor_num];
    MATH_T beta2 = per_tensor_beta2[tensor_num];
    MATH_T beta3 = per_tensor_beta3[tensor_num];
MATH_T beta1_correction, beta2_correction;
if (per_tensor_bias_correction[tensor_num] == 1) {
beta1_correction = 1 - pow(beta1, (MATH_T) step);
beta2_correction = 1 - pow(beta2, (MATH_T) step);
} else {
beta1_correction = (MATH_T) 1.0;
beta2_correction = (MATH_T) 1.0;
}
MATH_T epsilon = per_tensor_epsilon[tensor_num];
MATH_T decay = per_tensor_decay[tensor_num];
GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
g += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx*chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx*chunk_size;
MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];
u += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(g) &&
is_aligned(p) &&
is_aligned(m) &&
is_aligned(v))
{
GRAD_T l_g[ILP];
T l_p[ILP];
T l_m[ILP];
T l_v[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(l_g, g, 0, i_start);
if (decay != 0)
load_store(l_p, p, 0, i_start);
load_store(l_m, m, 0, i_start);
load_store(l_v, v, 0, i_start);
// unpack
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_g[ii] = l_g[ii];
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = l_p[ii];
}
r_m[ii] = l_m[ii];
r_v[ii] = l_v[ii];
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
l_m[ii] = r_m[ii];
l_v[ii] = r_v[ii];
}
// store
load_store(u, r_p, i_start, 0);
load_store(m, l_m, i_start, 0);
load_store(v, l_v, i_start, 0);
}
}
else
{
// see note in multi_tensor_scale_kernel.cu
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_g[ii] = g[i];
          // LAMB stage-1 special case: the parameter is only needed when decay != 0, so skip the load otherwise
if (decay == 0) {
r_p[ii] = MATH_T(0);
}
else {
r_p[ii] = p[i];
}
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
if (mode == MOMENT_MODE_0) {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
// L2 on scaled grad
scaled_grad = scaled_grad + decay*r_p[ii];
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = next_m_unbiased / denom;
}
else {
MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
u[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
}
};
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes new parameter value.
template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage2Functor
{
__device__ __forceinline__ void operator()(
int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<3>& tl,
const MATH_T* per_tensor_param_norm,
const MATH_T* per_tensor_update_norm,
const MATH_T learning_rate,
const MATH_T* per_tensor_decay,
bool use_nvlamb)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
MATH_T decay = per_tensor_decay[tensor_num];
MATH_T ratio = learning_rate;
// nvlamb: apply adaptive learning rate to all parameters
// otherwise, only apply to those with non-zero weight decay
if (use_nvlamb || (decay != (MATH_T) 0.0))
{
MATH_T param_norm = per_tensor_param_norm[tensor_num];
MATH_T update_norm = per_tensor_update_norm[tensor_num];
ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
}
MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
update += chunk_idx*chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx*chunk_size;
GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];
p_copy += chunk_idx*chunk_size;
n -= chunk_idx*chunk_size;
// to make things simple, we put aligned case in a different code path
if(n % ILP == 0 &&
chunk_size % ILP == 0 &&
is_aligned(p) &&
is_aligned(update))
{
T r_p[ILP];
MATH_T r_update[ILP];
GRAD_T r_p_copy[ILP];
for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
{
// load
load_store(r_p, p, 0, i_start);
load_store(r_update, update, 0, i_start);
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
convert(r_p[ii], r_p_copy[ii]);
}
load_store(p, r_p, i_start, 0);
load_store(p_copy, r_p_copy, i_start, 0);
}
}
else
{
for(int i_start = 0;
i_start < n && i_start < chunk_size;
i_start += blockDim.x*ILP)
{
MATH_T r_p[ILP];
MATH_T r_update[ILP];
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
r_p[ii] = p[i];
r_update[ii] = update[i];
}
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
}
#pragma unroll
for(int ii = 0; ii < ILP; ii++)
{
int i = i_start + threadIdx.x + ii*blockDim.x;
if(i < n && i < chunk_size)
{
p[i] = r_p[ii];
convert(r_p[ii], p_copy[i]);
}
}
}
}
}
};
void multi_tensor_lamb_compute_update_term_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_beta1,
at::Tensor per_tensor_beta2,
at::Tensor per_tensor_beta3,
at::Tensor per_tensor_bias_correction,
const int step,
at::Tensor per_tensor_epsilon,
const int mode,
at::Tensor per_tensor_decay,
const float global_grad_norm,
const float max_global_grad_norm)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1",
DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
multi_tensor_apply<5>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_beta1.DATA_PTR<scalar_t_2>(),
per_tensor_beta2.DATA_PTR<scalar_t_2>(),
per_tensor_beta3.DATA_PTR<scalar_t_2>(),
per_tensor_bias_correction.DATA_PTR<int>(),
step,
per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
(adamMode_t) mode,
per_tensor_decay.DATA_PTR<scalar_t_2>(),
(scalar_t_2) global_grad_norm,
(scalar_t_2) max_global_grad_norm); )))
AT_CUDA_CHECK(cudaGetLastError());
}
void multi_tensor_lamb_update_weights_cuda(
int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::Tensor per_tensor_param_norm,
at::Tensor per_tensor_update_norm,
const float learning_rate,
at::Tensor per_tensor_decay,
bool use_nvlamb)
{
using namespace at;
DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2",
DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2",
DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2",
multi_tensor_apply<3>(
BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
(scalar_t_2) learning_rate,
per_tensor_decay.DATA_PTR<scalar_t_2>(),
use_nvlamb); )))
AT_CUDA_CHECK(cudaGetLastError());
}
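Read together, the two functors split the LAMB step: stage 1 produces the Adam-style update term per parameter (with L2 or decoupled weight decay and global-gradient-norm clipping), and stage 2 rescales that term by the per-tensor trust ratio before applying it and writing a low-precision parameter copy. The Python sketch below mirrors that math for a single tensor; it is an illustration of the logic above, not the CUDA implementation, and the helper names are made up.

```python
import torch

def lamb_compute_update_term(g, p, m, v, beta1, beta2, beta3, step, bias_correction,
                             eps, decay, mode, global_grad_norm, max_global_grad_norm):
    clip = global_grad_norm / max_global_grad_norm if global_grad_norm > max_global_grad_norm else 1.0
    scaled_grad = g / clip
    if mode == 0 and decay != 0:          # MOMENT_MODE_0: L2 regularization on the scaled grad
        scaled_grad = scaled_grad + decay * p
    m.mul_(beta1).add_(beta3 * scaled_grad)
    v.mul_(beta2).add_((1 - beta2) * scaled_grad * scaled_grad)
    c1 = 1 - beta1 ** step if bias_correction else 1.0
    c2 = 1 - beta2 ** step if bias_correction else 1.0
    update = (m / c1) / ((v / c2).sqrt() + eps)
    if mode == 1:                         # MOMENT_MODE_1: decoupled weight decay
        update = update + decay * p
    return update

def lamb_update_weights(p, update, lr, decay, param_norm, update_norm, use_nvlamb):
    ratio = lr
    # nvlamb applies the adaptive rate everywhere; otherwise only where decay != 0.
    if (use_nvlamb or decay != 0.0) and param_norm != 0.0 and update_norm != 0.0:
        ratio = lr * (param_norm / update_norm)
    p.sub_(ratio * update)
    return p
```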
import torch
import torch.nn.functional as F
import argparse
from apex.contrib.multihead_attn import SelfMultiheadAttn
from apex.contrib.multihead_attn import EncdecMultiheadAttn
parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
parser.add_argument('--seed-start', default=1, type=int, help='First seed in the range of seeds to test')
parser.add_argument('--seed-end', default=100, type=int, help='Last seed in the range of seeds to test (inclusive)')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')
args = parser.parse_args()
assert args.seq_length % 64 == 0, "Sequence Length should be a multiple of 64!"
if not torch.cuda.is_available():
raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)
dropout_prob = 0.1
for seed in range(args.seed_start, args.seed_end+1) :
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
ref_layer = None
if args.encdec_attn :
ref_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
else :
ref_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='default')
ref_layer.cuda()
ref_layer.half()
ref_layer.reset_parameters()
ref_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
ref_inputs_kv = None
if args.encdec_attn :
ref_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
ref_grads = torch.randn_like(ref_inputs)
ref_outputs,_ = ref_layer.forward(ref_inputs,
ref_inputs_kv,
ref_inputs_kv,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=(not args.eval))
ref_outputs.backward(ref_grads)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
tst_layer = None
if args.encdec_attn :
tst_layer = EncdecMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
else:
tst_layer = SelfMultiheadAttn(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False, include_norm_add=args.norm_add, impl='fast')
tst_layer.cuda()
tst_layer.half()
tst_layer.reset_parameters()
tst_inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
tst_inputs_kv = None
if args.encdec_attn :
tst_inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim, dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
assert torch.equal(ref_inputs,tst_inputs), "ERROR: Inputs are different!"
tst_grads = torch.randn_like(tst_inputs)
tst_outputs,_ = tst_layer.forward(tst_inputs,
tst_inputs_kv,
tst_inputs_kv,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=(not args.eval))
tst_outputs.backward(tst_grads)
fwd_close = torch.equal(ref_outputs, tst_outputs)
bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)
diff_fwd = ref_outputs - tst_outputs
diff_cnt_fwd = diff_fwd.ne(0.0).sum()
diff_accum_fwd = diff_fwd.abs().sum()
diff_bwd = ref_inputs.grad - tst_inputs.grad
diff_cnt_bwd = diff_bwd.ne(0.0).sum()
diff_accum_bwd = diff_bwd.abs().sum()
print(">>> Seed: ", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(), bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())
 from .self_multihead_attn import SelfMultiheadAttn
 from .encdec_multihead_attn import EncdecMultiheadAttn
+from .mask_softmax_dropout_func import fast_mask_softmax_dropout_func