Unverified Commit 29b36315 authored by Hubert Lu, committed by GitHub

Cherry-picked the commit from upstream for faster --fast_multihead_attn build (#76)
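The speedup presumably comes from the "merge .so files" step below: the per-feature CUDA extensions (enc-dec attention, its norm-add variant, masked softmax-dropout, and so on) are built into a single `fast_multihead_attn` shared object instead of one `.so` per feature, so the same heavy headers are no longer recompiled for every separate extension. A minimal sketch of what such a merged extension could look like in `setup.py` follows; the source file names and flags are illustrative assumptions, not the actual apex build script. Inlining the rocBLAS datatype enums and moving device helpers into `.cuh` headers (visible in the diff) appears to be how the ODR/undefined-symbol issues mentioned in the intermediate commits were resolved.

```python
# Hypothetical sketch of a merged-extension build (NOT the actual apex setup.py).
# One CUDAExtension -> one .so, with all *_forward/*_backward bindings registered
# in a single frontend file; file names here are assumptions for illustration.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="fast_multihead_attn",
    ext_modules=[
        CUDAExtension(
            name="fast_multihead_attn",
            sources=[
                "multihead_attn_frontend.cpp",           # assumed: registers every binding
                "encdec_multihead_attn_cuda.cu",
                "encdec_multihead_attn_norm_add_cuda.cu",
                "self_multihead_attn_cuda.cu",
                "self_multihead_attn_norm_add_cuda.cu",
                "masked_softmax_dropout_cuda.cu",
            ],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```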



* Faster `--fast_multihead_attn` build (#1245)

* merge .so files

* odr

* fix build

* update import

* apply psf/black with max line length of 120

* update

* fix

* update

* build fixed again but undefined symbol again

* fix 2, still layer norm grad is undefined

* remove unused cpp files

* without layer_norm.cuh, import works

* import fast_multihead_attn works...

but why? Was the unnecessary `#include "layer_norm.cuh"` the culprit that prevented
the shared objects from linking `HostApplyLayerNorm` and
`HostLayerNormGradient`?

* clean up layer norm

* Fix some bugs
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 5ecad142
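For callers, the merge shows up as a rename: the separate `fast_encdec_multihead_attn`, `fast_encdec_multihead_attn_norm_add`, `fast_mask_softmax_dropout`, and `fast_additive_mask_softmax_dropout` modules are replaced by a single `import fast_multihead_attn`, with the old module name folded into the function name. The mapping below is read directly off the Python changes in this diff; argument lists are unchanged.

```python
# Old per-feature extension calls -> new functions on the single merged module.
# Taken from the renames visible in this diff; call signatures stay the same.
RENAMES = {
    "fast_encdec_multihead_attn.forward": "fast_multihead_attn.encdec_multihead_attn_forward",
    "fast_encdec_multihead_attn.backward": "fast_multihead_attn.encdec_multihead_attn_backward",
    "fast_encdec_multihead_attn_norm_add.forward": "fast_multihead_attn.encdec_multihead_attn_norm_add_forward",
    "fast_encdec_multihead_attn_norm_add.backward": "fast_multihead_attn.encdec_multihead_attn_norm_add_backward",
    "fast_mask_softmax_dropout.forward": "fast_multihead_attn.mask_softmax_dropout_forward",
    "fast_mask_softmax_dropout.backward": "fast_multihead_attn.mask_softmax_dropout_backward",
    "fast_additive_mask_softmax_dropout.forward": "fast_multihead_attn.additive_mask_softmax_dropout_forward",
    "fast_additive_mask_softmax_dropout.backward": "fast_multihead_attn.additive_mask_softmax_dropout_backward",
}
```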
...@@ -11,10 +11,10 @@ ...@@ -11,10 +11,10 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include "dropout.h" #include "dropout.cuh"
#include "layer_norm.h" #include "layer_norm.cuh"
#include "softmax.h" #include "softmax.cuh"
#include "strided_batched_gemm.h" #include "strided_batched_gemm.cuh"
namespace multihead_attn { namespace multihead_attn {
namespace self_norm_add { namespace self_norm_add {
...@@ -88,6 +88,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -88,6 +88,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
rocblas_int flags = 0;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half, float>( HostApplyLayerNorm<at::Half, float>(
...@@ -109,22 +111,22 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -109,22 +111,22 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
q_lin_results_ptr, q_lin_results_ptr,
c_type, rocblas_datatype_f16_r /*c_type*/,
output_lin_dim, output_lin_dim,
q_lin_results_ptr, q_lin_results_ptr,
d_type, rocblas_datatype_f16_r /*d_type*/,
output_lin_dim, output_lin_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
...@@ -148,7 +150,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -148,7 +150,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Padded Softmax // Padded Softmax
bool softmax_success = false; bool softmax_success = false;
...@@ -203,7 +205,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -203,7 +205,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
head_dim*attn_batches, head_dim*attn_batches,
head_dim, head_dim,
attn_batches, attn_batches,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...@@ -214,21 +216,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -214,21 +216,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_lin_results.data_ptr()), static_cast<void*>(output_lin_results.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
...@@ -317,6 +319,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -317,6 +319,8 @@ std::vector<torch::Tensor> bwd_cuda(
char a_layout_t{'t'}; char a_layout_t{'t'};
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
rocblas_int flags = 0;
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
...@@ -345,21 +349,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -345,21 +349,21 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(output_weights.data_ptr()), static_cast<const void*>(output_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()), static_cast<void*>(output_lin_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
...@@ -371,21 +375,21 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -371,21 +375,21 @@ std::vector<torch::Tensor> bwd_cuda(
batches, batches,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(matmul2_results.data_ptr()), static_cast<const void*>(matmul2_results.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()), static_cast<const void*>(dropout_add_grads.data_ptr()),
b_type, rocblas_datatype_f16_r /*b_type*/,
embed_dim, embed_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()), static_cast<void*>(output_weight_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
...@@ -409,7 +413,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -409,7 +413,7 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, k_seq_len,
k_seq_len*q_seq_len, k_seq_len*q_seq_len,
attn_batches, attn_batches,
flags); flags);
// Matmul2 Dgrad2 // Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -432,7 +436,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -432,7 +436,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Apply Dropout Mask and Scale by Dropout Probability // Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -472,7 +476,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -472,7 +476,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Matmul1 Dgrad2 // Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
...@@ -495,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -495,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda(
lead_dim, lead_dim,
batch_stride, batch_stride,
attn_batches, attn_batches,
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
...@@ -506,22 +510,22 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -506,22 +510,22 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim, output_lin_dim,
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
static_cast<const void*>(input_weights.data_ptr()), static_cast<const void*>(input_weights.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
//static_cast<void*>(input_grads.data_ptr()), //static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()), static_cast<void*>(input_lin_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
...@@ -534,27 +538,27 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -534,27 +538,27 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<const void*>(&alpha), static_cast<const void*>(&alpha),
//static_cast<const void*>(inputs.data_ptr()), //static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()), static_cast<const void*>(lyr_nrm_results.data_ptr()),
a_type, rocblas_datatype_f16_r /*a_type*/,
embed_dim, embed_dim,
static_cast<const void*>(q_lin_grads_ptr), static_cast<const void*>(q_lin_grads_ptr),
b_type, rocblas_datatype_f16_r /*b_type*/,
output_lin_dim, output_lin_dim,
static_cast<const void*>(&beta), static_cast<const void*>(&beta),
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
c_type, rocblas_datatype_f16_r /*c_type*/,
embed_dim, embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()), static_cast<void*>(input_weight_grads.data_ptr()),
d_type, rocblas_datatype_f16_r /*d_type*/,
embed_dim, embed_dim,
compute_type, rocblas_datatype_f32_r /*compute_type*/,
algo, rocblas_gemm_algo_standard /*algo*/,
solution_index, 0 /*solution_index*/,
flags)); flags));
// Fused Layer Norm Bwd with Residual Add // Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>( HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()), static_cast<const half *>(input_lin_grads.data_ptr()),
static_cast<half const *>(output_grads.data_ptr()), static_cast<const half *>(output_grads.data_ptr()),
static_cast<const float *>(lyr_nrm_mean.data_ptr()), static_cast<const float *>(lyr_nrm_mean.data_ptr()),
static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs, static_cast<const float *>(lyr_nrm_invvar.data_ptr()), inputs,
static_cast<int>(batches), // n1 static_cast<int>(batches), // n1
......
#pragma once #pragma once
#include "philox.h" #include "philox.cuh"
#include <ATen/cuda/CUDAGraphsUtils.cuh> #include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h> #include <curand_kernel.h>
...@@ -27,6 +27,14 @@ namespace { ...@@ -27,6 +27,14 @@ namespace {
template <typename Datatype, int ELEMENTS_PER_LDG> template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
template <> template <>
__device__ __inline__ void copy_vector<__half, 1>(__half *dst, __device__ __inline__ void copy_vector<__half, 1>(__half *dst,
const __half *src) { const __half *src) {
...@@ -55,10 +63,6 @@ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, ...@@ -55,10 +63,6 @@ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
*((half2 *)dst) = *((half2 *)src); *((half2 *)dst) = *((half2 *)src);
} }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_mask(Datatype *dst, Datatype value,
const uint8_t *src);
template <> template <>
__device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
const uint8_t *src) { const uint8_t *src) {
...@@ -66,14 +70,13 @@ __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value, ...@@ -66,14 +70,13 @@ __device__ __inline__ void apply_mask<__half, 1>(__half *dst, __half value,
*dst = value; *dst = value;
} }
} }
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void apply_additive_mask(Datatype *dst,
const Datatype *additive_mask);
template <> template <>
__device__ __inline__ void __device__ __inline__ void
apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) { apply_additive_mask<__half, 1>(__half *dst, const __half *additive_mask) {
*dst += *additive_mask; *dst += *additive_mask;
} }
template <> template <>
__device__ __inline__ void __device__ __inline__ void
apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
...@@ -82,7 +85,6 @@ apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) { ...@@ -82,7 +85,6 @@ apply_additive_mask<__half, 4>(__half *dst, const __half *additive_mask) {
*(dst + 2) += *(additive_mask + 2); *(dst + 2) += *(additive_mask + 2);
*(dst + 3) += *(additive_mask + 3); *(dst + 3) += *(additive_mask + 3);
} }
} // namespace
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Warp Softmax forward // Warp Softmax forward
...@@ -3142,4 +3144,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad, ...@@ -3142,4 +3144,4 @@ bool dispatch_masked_softmax_backward(output_t *grad_input, const input_t *grad,
} }
return false; return false;
} }
} // namespace
#pragma once
#include <iostream> #include <iostream>
#include <vector> #include <vector>
...@@ -15,18 +16,19 @@ ...@@ -15,18 +16,19 @@
//#include "cutlass/gemm/wmma_gemm_traits.h" //#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs // symbol to be automatically resolved by PyTorch libs
/*
rocblas_datatype a_type = rocblas_datatype_f16_r; rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
rocblas_datatype b_type = rocblas_datatype_f16_r; rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
rocblas_datatype c_type = rocblas_datatype_f16_r; rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
rocblas_datatype d_type = rocblas_datatype_f16_r; rocblas_datatype d_type = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r; rocblas_datatype compute_type = rocblas_datatype_f32_r;
rocblas_gemm_algo algo = rocblas_gemm_algo_standard; rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
int32_t solution_index = 0; int32_t solution_index = 0;
rocblas_int flags = 0; rocblas_int flags = 0;
*/
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) { cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't') if (trans == 't')
return CUBLAS_OP_T; return CUBLAS_OP_T;
...@@ -54,26 +56,26 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, ...@@ -54,26 +56,26 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, a_type, (int)lda, strideA, (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, b_type, (int)ldb, strideB, b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, c_type, (int)ldc, strideC, (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, d_type, int(ldd), strideD, d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, compute_type, algo, solution_index, flags)); (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
} }
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream(); auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) { if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) { } else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 't') ) { } else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags); } else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else { } else {
AT_ASSERTM(false, "TransA and TransB are invalid"); AT_ASSERTM(false, "TransA and TransB are invalid");
} }
...@@ -127,7 +129,7 @@ void HgemmStridedBatched(char transa, char transb, long m, ...@@ -127,7 +129,7 @@ void HgemmStridedBatched(char transa, char transb, long m,
// gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
// b, ldb, strideB, beta, c, ldc, strideC, batchCount); // b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA, gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, flags); b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, 0 /*flags*/);
} }
} // namespace
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include <ATen/cuda/CUDAGraphsUtils.cuh> #include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <c10/macros/Macros.h> #include <c10/macros/Macros.h>
#include "philox.h" #include "philox.cuh"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
// width should be a power of 2 and should be less than warpSize. // width should be a power of 2 and should be less than warpSize.
......
...@@ -5,16 +5,17 @@ from torch import nn ...@@ -5,16 +5,17 @@ from torch import nn
from torch.nn import Parameter from torch.nn import Parameter
import torch.nn.functional as F import torch.nn.functional as F
from .encdec_multihead_attn_func import encdec_attn_func from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm from apex.normalization.fused_layer_norm import FusedLayerNorm
if hasattr(torch._C, '_jit_set_profiling_executor') : if hasattr(torch._C, "_jit_set_profiling_executor"):
torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_executor(False)
if hasattr(torch._C, '_jit_set_profiling_mode') : if hasattr(torch._C, "_jit_set_profiling_mode"):
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
@torch.jit.script @torch.jit.script
def jit_dropout_add(x, residual, prob, is_training): def jit_dropout_add(x, residual, prob, is_training):
# type: (Tensor, Tensor, float, bool) -> Tensor # type: (Tensor, Tensor, float, bool) -> Tensor
...@@ -28,7 +29,8 @@ class EncdecMultiheadAttn(nn.Module): ...@@ -28,7 +29,8 @@ class EncdecMultiheadAttn(nn.Module):
See "Attention Is All You Need" for more details. See "Attention Is All You Need" for more details.
""" """
def __init__(self, embed_dim, num_heads, dropout=0., bias=False, include_norm_add=False, impl='fast'):
def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_add=False, impl="fast"):
super().__init__() super().__init__()
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_heads = num_heads self.num_heads = num_heads
...@@ -38,43 +40,49 @@ class EncdecMultiheadAttn(nn.Module): ...@@ -38,43 +40,49 @@ class EncdecMultiheadAttn(nn.Module):
self.bias = bias self.bias = bias
self.include_norm_add = include_norm_add self.include_norm_add = include_norm_add
self.impl = impl self.impl = impl
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim ** -0.5
self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim)) self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
self.in_proj_weight_kv = Parameter(torch.Tensor(2*embed_dim, embed_dim)) self.in_proj_weight_kv = Parameter(torch.Tensor(2 * embed_dim, embed_dim))
self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) self.out_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
if self.bias: if self.bias:
assert impl != 'fast', "ERROR! The Fast implementation does not support biases!" assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim)) self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
self.in_proj_bias_kv = Parameter(torch.Tensor(2*embed_dim)) self.in_proj_bias_kv = Parameter(torch.Tensor(2 * embed_dim))
self.out_proj_bias = Parameter(torch.Tensor(embed_dim)) self.out_proj_bias = Parameter(torch.Tensor(embed_dim))
else: else:
self.register_parameter('in_proj_bias_q', None) self.register_parameter("in_proj_bias_q", None)
self.register_parameter('in_proj_bias_kv', None) self.register_parameter("in_proj_bias_kv", None)
self.in_proj_bias_q = None self.in_proj_bias_q = None
self.in_proj_bias_kv = None self.in_proj_bias_kv = None
self.out_proj_bias = None self.out_proj_bias = None
if self.include_norm_add: if self.include_norm_add:
if impl == 'fast': if impl == "fast":
self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim)) self.lyr_nrm_gamma_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim)) self.lyr_nrm_beta_weights = Parameter(torch.Tensor(embed_dim))
self.lyr_nrm = None self.lyr_nrm = None
else: else:
self.register_parameter('lyr_norm_gamma_weights', None) self.register_parameter("lyr_norm_gamma_weights", None)
self.register_parameter('lyr_norm_beta_weights', None) self.register_parameter("lyr_norm_beta_weights", None)
self.lyr_nrm_gamma_weights = None self.lyr_nrm_gamma_weights = None
self.lyr_nrm_beta_weights = None self.lyr_nrm_beta_weights = None
self.lyr_nrm = FusedLayerNorm(embed_dim) self.lyr_nrm = FusedLayerNorm(embed_dim)
self.reset_parameters() self.reset_parameters()
if self.include_norm_add: if self.include_norm_add:
if impl == 'fast' : self.attn_func = fast_encdec_attn_norm_add_func if impl == "fast":
elif impl == 'default' : self.attn_func = encdec_attn_func self.attn_func = fast_encdec_attn_norm_add_func
else : assert False, "Unsupported impl: {} !".format(impl) elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
else: else:
if impl == 'fast' : self.attn_func = fast_encdec_attn_func if impl == "fast":
elif impl == 'default' : self.attn_func = encdec_attn_func self.attn_func = fast_encdec_attn_func
else : assert False, "Unsupported impl: {} !".format(impl) elif impl == "default":
self.attn_func = encdec_attn_func
else:
assert False, "Unsupported impl: {} !".format(impl)
def reset_parameters(self): def reset_parameters(self):
nn.init.xavier_uniform_(self.in_proj_weight_q) nn.init.xavier_uniform_(self.in_proj_weight_q)
...@@ -85,11 +93,11 @@ class EncdecMultiheadAttn(nn.Module): ...@@ -85,11 +93,11 @@ class EncdecMultiheadAttn(nn.Module):
nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5)) nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
nn.init.xavier_uniform_(self.out_proj_weight) nn.init.xavier_uniform_(self.out_proj_weight)
if self.bias: if self.bias:
nn.init.constant_(self.in_proj_bias_q, 0.) nn.init.constant_(self.in_proj_bias_q, 0.0)
nn.init.constant_(self.in_proj_bias_kv, 0.) nn.init.constant_(self.in_proj_bias_kv, 0.0)
nn.init.constant_(self.out_proj_bias, 0.) nn.init.constant_(self.out_proj_bias, 0.0)
if self.include_norm_add: if self.include_norm_add:
if self.impl == 'fast' : if self.impl == "fast":
nn.init.ones_(self.lyr_nrm_gamma_weights) nn.init.ones_(self.lyr_nrm_gamma_weights)
nn.init.zeros_(self.lyr_nrm_beta_weights) nn.init.zeros_(self.lyr_nrm_beta_weights)
else: else:
...@@ -106,7 +114,7 @@ class EncdecMultiheadAttn(nn.Module): ...@@ -106,7 +114,7 @@ class EncdecMultiheadAttn(nn.Module):
""" """
if key_padding_mask is not None: if key_padding_mask is not None:
assert (attn_mask is None), "ERROR attn_mask and key_padding_mask should not be both defined!" assert attn_mask is None, "ERROR attn_mask and key_padding_mask should not be both defined!"
mask = key_padding_mask mask = key_padding_mask
elif attn_mask is not None: elif attn_mask is not None:
mask = attn_mask mask = attn_mask
...@@ -114,28 +122,73 @@ class EncdecMultiheadAttn(nn.Module): ...@@ -114,28 +122,73 @@ class EncdecMultiheadAttn(nn.Module):
mask = None mask = None
if self.include_norm_add: if self.include_norm_add:
if self.impl == 'fast': if self.impl == "fast":
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key, outputs = self.attn_func(
self.lyr_nrm_gamma_weights, self.lyr_nrm_beta_weights, attn_mask is not None,
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout) is_training,
self.num_heads,
query,
key,
self.lyr_nrm_gamma_weights,
self.lyr_nrm_beta_weights,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else: else:
lyr_nrm_results = self.lyr_nrm(query) lyr_nrm_results = self.lyr_nrm(query)
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, lyr_nrm_results, key, outputs = self.attn_func(
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, attn_mask is not None,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias, is_training,
mask, self.dropout) self.num_heads,
self.scaling,
lyr_nrm_results,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
if is_training: if is_training:
outputs = jit_dropout_add(outputs, query, self.dropout, is_training) outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
else: else:
outputs = outputs + query outputs = outputs + query
else: else:
if self.impl == 'fast': if self.impl == "fast":
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, query, key, outputs = self.attn_func(
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, mask, self.dropout) attn_mask is not None,
is_training,
self.num_heads,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
mask,
self.dropout,
)
else: else:
outputs = self.attn_func(attn_mask is not None, is_training, self.num_heads, self.scaling, query, key, outputs = self.attn_func(
self.in_proj_weight_q, self.in_proj_weight_kv, self.out_proj_weight, attn_mask is not None,
self.in_proj_bias_q, self.in_proj_bias_kv, self.out_proj_bias, is_training,
mask, self.dropout) self.num_heads,
self.scaling,
query,
key,
self.in_proj_weight_q,
self.in_proj_weight_kv,
self.out_proj_weight,
self.in_proj_bias_q,
self.in_proj_bias_kv,
self.out_proj_bias,
mask,
self.dropout,
)
return outputs,None return outputs, None
import torch import torch
import fast_encdec_multihead_attn
import fast_multihead_attn
class FastEncdecAttnFunc(torch.autograd.Function): class FastEncdecAttnFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob): def forward(
heads_t = torch.tensor([heads]) ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = pad_mask is not None
input_lin_q_results, \ (
input_lin_kv_results, \ input_lin_q_results,
softmax_results, \ input_lin_kv_results,
dropout_results, \ softmax_results,
dropout_mask, \ dropout_results,
matmul2_results, \ dropout_mask,
outputs = \ matmul2_results,
fast_encdec_multihead_attn.forward( \ outputs,
use_mask, \ ) = fast_multihead_attn.encdec_multihead_attn_forward(
use_time_mask, \ use_mask,
is_training, \ use_time_mask,
heads, \ is_training,
inputs_q, \ heads,
inputs_kv, \ inputs_q,
input_weights_q, \ inputs_kv,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
pad_mask if use_mask else null_tensor, \ output_weights,
dropout_prob) pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(heads_t, \ ctx.save_for_backward(
matmul2_results, \ heads_t,
dropout_results, \ matmul2_results,
softmax_results, \ dropout_results,
input_lin_q_results, \ softmax_results,
input_lin_kv_results, \ input_lin_q_results,
inputs_q, \ input_lin_kv_results,
inputs_kv, \ inputs_q,
input_weights_q, \ inputs_kv,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
dropout_mask, \ output_weights,
dropout_prob_t) dropout_mask,
dropout_prob_t,
)
return outputs.detach() return outputs.detach()
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
heads_t, \ (
matmul2_results, \ heads_t,
dropout_results, \ matmul2_results,
softmax_results, \ dropout_results,
input_lin_q_results, \ softmax_results,
input_lin_kv_results, \ input_lin_q_results,
inputs_q, \ input_lin_kv_results,
inputs_kv, \ inputs_q,
input_weights_q, \ inputs_kv,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
dropout_mask, \ output_weights,
dropout_prob_t = ctx.saved_tensors dropout_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
inputs_q,
inputs_kv,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_prob_t[0],
)
input_q_grads, \ return (
input_kv_grads, \ None,
input_weight_q_grads, \ None,
input_weight_kv_grads, \ None,
output_weight_grads = \ input_q_grads,
fast_encdec_multihead_attn.backward( \ input_kv_grads,
heads_t[0], \ input_weight_q_grads,
output_grads, \ input_weight_kv_grads,
matmul2_results, \ output_weight_grads,
dropout_results, \ None,
softmax_results, \ None,
input_lin_q_results, \ )
input_lin_kv_results, \
inputs_q, \
inputs_kv, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_prob_t[0])
return None, None, None, input_q_grads, input_kv_grads, input_weight_q_grads, input_weight_kv_grads, output_weight_grads, None, None
fast_encdec_attn_func = FastEncdecAttnFunc.apply fast_encdec_attn_func = FastEncdecAttnFunc.apply
...@@ -6,125 +6,154 @@ ...@@ -6,125 +6,154 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import torch import torch
import fast_encdec_multihead_attn_norm_add
import fast_multihead_attn
class FastEncdecAttnNormAddFunc(torch.autograd.Function): class FastEncdecAttnNormAddFunc(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, use_time_mask, is_training, heads, inputs_q, inputs_kv, lyr_nrm_gamma_weights, lyr_nrm_beta_weights, input_weights_q, input_weights_kv, output_weights, pad_mask, dropout_prob): def forward(
heads_t = torch.tensor([heads]) ctx,
use_time_mask,
is_training,
heads,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
pad_mask,
dropout_prob,
):
heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = pad_mask is not None
lyr_nrm_results, \ (
lyr_nrm_mean, \ lyr_nrm_results,
lyr_nrm_invvar, \ lyr_nrm_mean,
input_lin_q_results, \ lyr_nrm_invvar,
input_lin_kv_results, \ input_lin_q_results,
softmax_results, \ input_lin_kv_results,
dropout_results, \ softmax_results,
dropout_mask, \ dropout_results,
matmul2_results, \ dropout_mask,
dropout_add_mask, \ matmul2_results,
outputs = \ dropout_add_mask,
fast_encdec_multihead_attn_norm_add.forward( \ outputs,
use_mask, \ ) = fast_multihead_attn.encdec_multihead_attn_norm_add_forward(
use_time_mask, \ use_mask,
is_training, \ use_time_mask,
heads, \ is_training,
inputs_q, \ heads,
inputs_kv, \ inputs_q,
lyr_nrm_gamma_weights, \ inputs_kv,
lyr_nrm_beta_weights, \ lyr_nrm_gamma_weights,
input_weights_q, \ lyr_nrm_beta_weights,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
pad_mask if use_mask else null_tensor, \ output_weights,
dropout_prob) pad_mask if use_mask else null_tensor,
dropout_prob,
)
ctx.save_for_backward(heads_t, \ ctx.save_for_backward(
matmul2_results, \ heads_t,
dropout_results, \ matmul2_results,
softmax_results, \ dropout_results,
input_lin_q_results, \ softmax_results,
input_lin_kv_results, \ input_lin_q_results,
lyr_nrm_results, \ input_lin_kv_results,
lyr_nrm_mean, \ lyr_nrm_results,
lyr_nrm_invvar, \ lyr_nrm_mean,
inputs_q, \ lyr_nrm_invvar,
inputs_kv, \ inputs_q,
lyr_nrm_gamma_weights, \ inputs_kv,
lyr_nrm_beta_weights, \ lyr_nrm_gamma_weights,
input_weights_q, \ lyr_nrm_beta_weights,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
dropout_mask, \ output_weights,
dropout_add_mask, \ dropout_mask,
dropout_prob_t) dropout_add_mask,
dropout_prob_t,
)
return outputs.detach() return outputs.detach()
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
heads_t, \ (
matmul2_results, \ heads_t,
dropout_results, \ matmul2_results,
softmax_results, \ dropout_results,
input_lin_q_results, \ softmax_results,
input_lin_kv_results, \ input_lin_q_results,
lyr_nrm_results, \ input_lin_kv_results,
lyr_nrm_mean, \ lyr_nrm_results,
lyr_nrm_invvar, \ lyr_nrm_mean,
inputs_q, \ lyr_nrm_invvar,
inputs_kv, \ inputs_q,
lyr_nrm_gamma_weights, \ inputs_kv,
lyr_nrm_beta_weights, \ lyr_nrm_gamma_weights,
input_weights_q, \ lyr_nrm_beta_weights,
input_weights_kv, \ input_weights_q,
output_weights, \ input_weights_kv,
dropout_mask, \ output_weights,
dropout_add_mask, \ dropout_mask,
dropout_prob_t = ctx.saved_tensors dropout_add_mask,
dropout_prob_t,
) = ctx.saved_tensors
(
input_q_grads,
input_kv_grads,
lyr_nrm_gamma_grads,
lyr_nrm_beta_grads,
input_weight_q_grads,
input_weight_kv_grads,
output_weight_grads,
) = fast_multihead_attn.encdec_multihead_attn_norm_add_backward(
heads_t[0],
output_grads,
matmul2_results,
dropout_results,
softmax_results,
input_lin_q_results,
input_lin_kv_results,
lyr_nrm_results,
lyr_nrm_mean,
lyr_nrm_invvar,
inputs_q,
inputs_kv,
lyr_nrm_gamma_weights,
lyr_nrm_beta_weights,
input_weights_q,
input_weights_kv,
output_weights,
dropout_mask,
dropout_add_mask,
dropout_prob_t[0],
)
input_q_grads, \ # import pdb; pdb.set_trace()
input_kv_grads, \ return (
lyr_nrm_gamma_grads, \ None,
lyr_nrm_beta_grads, \ None,
input_weight_q_grads, \ None,
input_weight_kv_grads, \ input_q_grads,
output_weight_grads = \ input_kv_grads,
fast_encdec_multihead_attn_norm_add.backward( \ lyr_nrm_gamma_grads,
heads_t[0], \ lyr_nrm_beta_grads,
output_grads, \ input_weight_q_grads,
matmul2_results, \ input_weight_kv_grads,
dropout_results, \ output_weight_grads,
softmax_results, \ None,
input_lin_q_results, \ None,
input_lin_kv_results, \ )
lyr_nrm_results, \
lyr_nrm_mean, \
lyr_nrm_invvar, \
inputs_q, \
inputs_kv, \
lyr_nrm_gamma_weights, \
lyr_nrm_beta_weights, \
input_weights_q, \
input_weights_kv, \
output_weights, \
dropout_mask, \
dropout_add_mask, \
dropout_prob_t[0])
#import pdb; pdb.set_trace()
return None, None, None, \
input_q_grads, \
input_kv_grads, \
lyr_nrm_gamma_grads, \
lyr_nrm_beta_grads, \
input_weight_q_grads, \
input_weight_kv_grads, \
output_weight_grads, \
None, None
fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply fast_encdec_attn_norm_add_func = FastEncdecAttnNormAddFunc.apply
import torch import torch
import fast_mask_softmax_dropout
import fast_additive_mask_softmax_dropout
import fast_multihead_attn
class MaskSoftmaxDropout(torch.autograd.Function) :
class MaskSoftmaxDropout(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob): def forward(ctx, is_training, heads, inputs, pad_mask, mask_additive, dropout_prob):
heads_t = torch.tensor([heads]) heads_t = torch.tensor([heads])
dropout_prob_t = torch.tensor([dropout_prob]) dropout_prob_t = torch.tensor([dropout_prob])
null_tensor = torch.tensor([]) null_tensor = torch.tensor([])
use_mask = (pad_mask is not None) use_mask = pad_mask is not None
use_mask_t = torch.tensor([use_mask]) use_mask_t = torch.tensor([use_mask])
mask_additive_t = torch.tensor([mask_additive]) mask_additive_t = torch.tensor([mask_additive])
if mask_additive: if mask_additive:
dropout_results, \ dropout_results, dropout_mask, softmax_results = fast_multihead_attn.additive_mask_softmax_dropout_forward(
dropout_mask, \ use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
softmax_results = \ )
fast_additive_mask_softmax_dropout.forward( \ # fast_additive_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
else: else:
dropout_results, \ dropout_results, dropout_mask, softmax_results = fast_multihead_attn.mask_softmax_dropout_forward(
dropout_mask, \ use_mask, is_training, heads, inputs, pad_mask if use_mask else null_tensor, dropout_prob
softmax_results = \ )
fast_mask_softmax_dropout.forward( \ # fast_mask_softmax_dropout.forward( \
use_mask, \
is_training, \
heads, \
inputs, \
pad_mask if use_mask else null_tensor, \
dropout_prob)
ctx.save_for_backward( ctx.save_for_backward(
use_mask_t, \ use_mask_t,
heads_t, \ heads_t,
softmax_results, \ softmax_results,
dropout_mask, \ dropout_mask,
pad_mask if use_mask else null_tensor, \ pad_mask if use_mask else null_tensor,
mask_additive_t, \ mask_additive_t,
dropout_prob_t) dropout_prob_t,
)
return dropout_results.detach() return dropout_results.detach()
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
use_mask_t, \ (
heads_t, \ use_mask_t,
softmax_results, \ heads_t,
dropout_mask, \ softmax_results,
pad_mask, \ dropout_mask,
mask_additive_t, \ pad_mask,
dropout_prob_t = ctx.saved_tensors mask_additive_t,
dropout_prob_t,
) = ctx.saved_tensors
if mask_additive_t[0]: if mask_additive_t[0]:
input_grads = \ input_grads = fast_multihead_attn.additive_mask_softmax_dropout_backward(
fast_additive_mask_softmax_dropout.backward( \ use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, dropout_prob_t[0]
use_mask_t[0], \ )
heads_t[0], \ # fast_additive_mask_softmax_dropout.backward( \
output_grads, \
softmax_results, \
dropout_mask, \
dropout_prob_t[0])
else: else:
input_grads = \ input_grads = fast_multihead_attn.mask_softmax_dropout_backward(
fast_mask_softmax_dropout.backward( \ use_mask_t[0], heads_t[0], output_grads, softmax_results, dropout_mask, pad_mask, dropout_prob_t[0]
use_mask_t[0], \ )
heads_t[0], \ # fast_mask_softmax_dropout.backward( \
output_grads, \
softmax_results, \
dropout_mask, \
pad_mask, \
dropout_prob_t[0])
return None, None, input_grads, None, None, None return None, None, input_grads, None, None, None
fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply fast_mask_softmax_dropout_func = MaskSoftmaxDropout.apply