OpenDAS / apex · Commits · Commit 9f899769 (Unverified)

Merge pull request #56 from ROCmSoftwarePlatform/dev/hubertlu/multihead_attn
Enable multihead atten

Authored Nov 02, 2021 by Hubert Lu; committed by GitHub on Nov 02, 2021.
Parents: 325246e4, 62f06964
Showing 20 changed files with 724 additions and 403 deletions (+724 -403):
  .gitignore  (+3 -0)
  apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu  (+1 -1)
  apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp  (+0 -0)
  apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu  (+1 -1)
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp  (+4 -4)
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  (+125 -67)
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp  (+3 -3)
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (+129 -72)
  apex/contrib/csrc/multihead_attn/layer_norm.h  (+25 -14)
  apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp  (+0 -0)
  apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu  (+5 -1)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp  (+4 -4)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (+105 -60)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp  (+4 -4)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  (+96 -52)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp  (+4 -4)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  (+91 -46)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp  (+3 -3)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  (+96 -49)
  apex/contrib/csrc/multihead_attn/softmax.h  (+25 -18)
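The bulk of this diff replaces cuBLAS mixed-precision GEMM calls with their rocBLAS counterparts. As a hedged, self-contained sketch of that recurring mapping (parameter names such as m, n, k, d_A, d_B, d_C are illustrative placeholders, not identifiers from the commit): cublasGemmEx takes a single output matrix C and an algorithm enum, while rocblas_gemm_ex additionally takes a separate output matrix D plus a solution index and flags, which is why every converted call gains an extra output pointer/type/leading-dimension triple and an algo/solution_index/flags tail.

// Hedged sketch of the recurring rewrite, not code from this commit.
//
// cuBLAS (old form):
//   cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
//                &alpha, d_A, CUDA_R_16F, lda,
//                        d_B, CUDA_R_16F, ldb,
//                &beta,  d_C, CUDA_R_16F, ldc,
//                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
//
// rocBLAS (new form): same math, but C (input) and D (output) are passed
// separately, and the algorithm is an algo enum plus solution index and flags.
#include <rocblas.h>  // assumption: the plain rocblas.h include path

rocblas_status gemm_fp16_fp32_accum(rocblas_handle handle,
                                    int m, int n, int k,
                                    const float* alpha,
                                    const void* d_A, int lda,
                                    const void* d_B, int ldb,
                                    const float* beta,
                                    void* d_C, int ldc) {
  return rocblas_gemm_ex(handle,
                         rocblas_operation_transpose, rocblas_operation_none,
                         m, n, k,
                         alpha,
                         d_A, rocblas_datatype_f16_r, lda,
                         d_B, rocblas_datatype_f16_r, ldb,
                         beta,
                         d_C, rocblas_datatype_f16_r, ldc,   // C: input
                         d_C, rocblas_datatype_f16_r, ldc,   // D: written in place of C
                         rocblas_datatype_f32_r,             // accumulate in fp32
                         rocblas_gemm_algo_standard, 0, 0);  // algo, solution_index, flags
}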
.gitignore  (View file @ 9f899769)

@@ -4,3 +4,6 @@ build
 docs/build
 *~
 __pycache__
+*.hip
+*_hip.*
+*hip*
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu  (View file @ 9f899769)

@@ -183,4 +183,4 @@ void ln_fwd_cuda(
     assert(false && "Not implemented");
   }
-}
\ No newline at end of file
+}
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp  (View file @ 9f899769)
File moved
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu  (View file @ 9f899769)

@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp  (View file @ 9f899769)

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
                      );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace encdec
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
+  m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
 }
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  (View file @ 9f899769)

#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

@@ -22,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -86,9 +89,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_q_dim,

@@ -96,20 +99,25 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(inputs_q.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_q_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_kv_dim,

@@ -117,17 +125,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                k_lin_results_ptr, rocblas_datatype_f16_r, output_lin_kv_dim,
+                k_lin_results_ptr, rocblas_datatype_f16_r, output_lin_kv_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,

@@ -146,6 +159,9 @@ std::vector<torch::Tensor> fwd_cuda(
                        beta,
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(softmax_results_ptr),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);

@@ -208,10 +224,13 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
+                       static_cast<half*>(matmul2_results.data_ptr()),
+                       head_dim*attn_batches,
+                       head_dim,
                        attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -219,20 +238,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   return {input_lin_q_results,

@@ -312,10 +333,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -323,20 +342,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -344,17 +368,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_q,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,

@@ -374,6 +403,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(matmul2_grads.data_ptr()),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Matmul2 Dgrad2

@@ -394,6 +426,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
+                       v_lin_grads_ptr,
+                       lead_dim_kv,
+                       batch_stride_kv,
                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability

@@ -433,6 +468,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim_q,
                        batch_stride_q,
+                       q_lin_grads_ptr,
+                       lead_dim_q,
+                       batch_stride_q,
                        attn_batches);
   // Matmul1 Dgrad2

@@ -453,10 +491,13 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
+                       k_lin_grads_ptr,
+                       lead_dim_kv,
+                       batch_stride_kv,
                        attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -464,21 +505,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_q_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim,
+                static_cast<const void*>(input_weights_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_q_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -486,20 +531,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_q,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs_q.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim,
+                static_cast<const void*>(inputs_q.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_q_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_weight_q_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -507,21 +557,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_kv_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim,
+                static_cast<const void*>(input_weights_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_kv_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -529,18 +583,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_kv,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim,
+                static_cast<const void*>(inputs_kv.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(k_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_kv_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_weight_kv_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   return {input_q_grads,

@@ -551,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn
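Besides the GEMM type enums, every gemm_switch_fp32accum call in this file gains a second pointer/leading-dimension/stride triple. A plausible reading (hedged; gemm_switch_fp32accum is an apex helper whose new signature is not shown in these hunks) is that it now forwards separate C and D matrices to rocblas_gemm_strided_batched_ex, which, unlike the cuBLAS strided-batched call, takes an explicit output matrix D in addition to C. A minimal illustration with placeholder names, not the repository's implementation:

// Hedged illustration only: placeholder wrapper, assumed rocBLAS API usage.
#include <rocblas.h>

rocblas_status batched_fp16_gemm_fp32_accum(
    rocblas_handle handle, int m, int n, int k,
    const float* alpha,
    const rocblas_half* A, int lda, rocblas_stride strideA,
    const rocblas_half* B, int ldb, rocblas_stride strideB,
    const float* beta,
    rocblas_half* C, int ldc, rocblas_stride strideC,
    int batch_count) {
  return rocblas_gemm_strided_batched_ex(
      handle, rocblas_operation_transpose, rocblas_operation_none, m, n, k,
      alpha,
      A, rocblas_datatype_f16_r, lda, strideA,
      B, rocblas_datatype_f16_r, ldb, strideB,
      beta,
      C, rocblas_datatype_f16_r, ldc, strideC,   // C (input)
      C, rocblas_datatype_f16_r, ldc, strideC,   // D (output), aliased to C
      batch_count,
      rocblas_datatype_f32_r,                    // accumulate in fp32
      rocblas_gemm_algo_standard, 0, 0);
}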
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp  (View file @ 9f899769)

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd(
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
-  m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
+  m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
+  m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
 }
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (View file @ 9f899769)

#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

@@ -21,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace encdec_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm
   HostApplyLayerNorm<at::Half,float>(
                      static_cast<at::Half*>(lyr_nrm_results.data_ptr()),

@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda(
                      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_q_dim,

@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights_q.data_ptr()), a_type, embed_dim,
                 //static_cast<const void*>(inputs_q.data_ptr()),
-                static_cast<const void*>(lyr_nrm_results.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(lyr_nrm_results.data_ptr()), b_type, embed_dim,
                 static_cast<const void*>(&beta),
-                q_lin_results_ptr, CUDA_R_16F, output_lin_q_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr, c_type, output_lin_q_dim,
+                q_lin_results_ptr, d_type, output_lin_q_dim,
+                compute_type, algo, solution_index, flags));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_kv_dim,

@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights_kv.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(inputs_kv.data_ptr()), b_type, embed_dim,
                 static_cast<const void*>(&beta),
-                k_lin_results_ptr, CUDA_R_16F, output_lin_kv_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                k_lin_results_ptr, c_type, output_lin_kv_dim,
+                k_lin_results_ptr, d_type, output_lin_kv_dim,
+                compute_type, algo, solution_index, flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
                         a_layout_t,

@@ -168,7 +180,10 @@ std::vector<torch::Tensor> fwd_cuda(
                        beta,
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(softmax_results_ptr),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Padded Softmax

@@ -230,11 +245,14 @@ std::vector<torch::Tensor> fwd_cuda(
                        beta,
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
+                       static_cast<half*>(matmul2_results.data_ptr()),
+                       head_dim*attn_batches,
+                       head_dim,
                        attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -242,19 +260,23 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), b_type, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_lin_results.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_results.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(output_lin_results.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // End-of-block Dropout-Add
   if (is_training) {
     apex_dropout_add_cuda<at::Half,float,uint32_t>(

@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda(
                        total_tokens_q);
   }
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {lyr_nrm_results,
           lyr_nrm_mean,

@@ -366,9 +386,7 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
                      static_cast<at::Half const*>(output_grads.data_ptr()),

@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda(
                      (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(dropout_add_grads.data_ptr()), b_type, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(output_lin_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_q,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(dropout_add_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(dropout_add_grads.data_ptr()), b_type, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(output_weight_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,

@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(matmul2_grads.data_ptr()),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Matmul2 Dgrad2

@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
+                       v_lin_grads_ptr,
+                       lead_dim_kv,
+                       batch_stride_kv,
                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability

@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim_q,
                        batch_stride_q,
+                       q_lin_grads_ptr,
+                       lead_dim_q,
+                       batch_stride_q,
                        attn_batches);
   // Matmul1 Dgrad2

@@ -515,11 +552,14 @@ std::vector<torch::Tensor> bwd_cuda(
                        beta,
                        k_lin_grads_ptr,
                        lead_dim_kv,
                        batch_stride_kv,
+                       k_lin_grads_ptr,
+                       lead_dim_kv,
+                       batch_stride_kv,
                        attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_q_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_q.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim,
+                static_cast<const void*>(input_weights_q.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), b_type, output_lin_q_dim,
                 static_cast<const void*>(&beta),
                 //static_cast<void*>(input_q_grads.data_ptr()),
-                static_cast<void*>(input_lin_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_lin_q_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(input_lin_q_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_q,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs_q.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_q_dim,
+                static_cast<const void*>(inputs_q.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), b_type, output_lin_q_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_q_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_q_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(input_weight_q_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_kv_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim,
+                static_cast<const void*>(input_weights_kv.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(k_lin_grads_ptr), b_type, output_lin_kv_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_kv_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(input_kv_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches_kv,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs_kv.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(k_lin_grads_ptr), CUDA_R_16F, output_lin_kv_dim,
+                static_cast<const void*>(inputs_kv.data_ptr()), a_type, embed_dim,
+                static_cast<const void*>(k_lin_grads_ptr), b_type, output_lin_kv_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_kv_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_kv_grads.data_ptr()), c_type, embed_dim,
+                static_cast<void*>(input_weight_kv_grads.data_ptr()), d_type, embed_dim,
+                compute_type, algo, solution_index, flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half,float>(

@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda(
                      static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
                      );
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_q_grads,

@@ -635,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/layer_norm.h  (View file @ 9f899769)

@@ -4,6 +4,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 template<typename U> __device__
 void cuWelfordOnlineSum(
   const U curr,

@@ -84,9 +85,9 @@ void cuWelfordMuSigma2(
     // intra-warp reductions
     for (int l = 0; l <= 4; ++l) {
       int srcLaneB = (threadIdx.x + (1 << l)) & 31;
-      U muB = WARP_SHFL(mu, srcLaneB);
-      U countB = WARP_SHFL(count, srcLaneB);
-      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
+      U muB = WARP_SHFL(mu, srcLaneB, 32);
+      U countB = WARP_SHFL(count, srcLaneB, 32);
+      U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
       cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
     }
     // threadIdx.x == 0 has correct values for each warp

@@ -122,8 +123,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/U(n2);
         // don't care about final value of count, we know count == n2
       } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/U(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
       }
     }
   }

@@ -180,9 +181,9 @@ void cuWelfordMuSigma2(
     // intra-warp reductions
     for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x + (1 << l)) & 31;
-      float muB = WARP_SHFL(mu, srcLaneB);
-      float countB = WARP_SHFL(count, srcLaneB);
-      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
+      float muB = WARP_SHFL(mu, srcLaneB, 32);
+      float countB = WARP_SHFL(count, srcLaneB, 32);
+      float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
       cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
     }
     // threadIdx.x == 0 has correct values for each warp

@@ -218,8 +219,8 @@ void cuWelfordMuSigma2(
         sigma2 = ubuf[1]/float(n2);
         // don't care about final value of count, we know count == n2
       } else {
-        mu = WARP_SHFL(mu, 0);
-        sigma2 = WARP_SHFL(sigma2/float(n2), 0);
+        mu = WARP_SHFL(mu, 0, 32);
+        sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
       }
     }
   }

@@ -227,9 +228,19 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
   return U(1) / sqrt(v);
 }
+//template<> float rsqrt(float v) {
+//  return rsqrtf(v);
+//}
+#if defined __HIP_PLATFORM_HCC__
+__device__ float rsqrt(float v) {
+  return rsqrtf(v);
+}
+#else
 template<> float rsqrt(float v) {
   return rsqrtf(v);
 }
+#endif
 template<> double rsqrt(double v) {
   return rsqrt(v);
 }

@@ -290,7 +301,7 @@ void cuApplyLayerNorm(
 // 1) blockDim.x == warpSize
 // 2) Tensors are contiguous
 //
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu, sigma2;

@@ -529,7 +540,7 @@ void cuComputeGradInput(
     const T* gamma,
     T* grad_input)
 {
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];

@@ -574,8 +585,8 @@ void cuComputeGradInput(
     }
     // intra-warp reductions
     for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
-      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
-      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
+      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
+      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
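In layer_norm.h the warp-shuffle helpers gain an explicit width argument of 32. A hedged interpretation: on many AMD GPUs the wavefront (warpSize) is 64, so without a fixed width the reductions above, which assume 32 lanes (srcLaneB masked with & 31, five shuffle rounds), would exchange values across the whole wavefront. A sketch of how a width-aware WARP_SHFL / WARP_SHFL_XOR pair could be defined (illustrative only; the repository's actual macros live elsewhere and are not shown in this diff):

// Hedged sketch, not the macros from this repository.
#if defined(__HIP_PLATFORM_HCC__)
  // HIP builds: __shfl/__shfl_xor take (var, lane, width).
  #define WARP_SHFL(var, src_lane, width)      __shfl((var), (src_lane), (width))
  #define WARP_SHFL_XOR(var, lane_mask, width) __shfl_xor((var), (lane_mask), (width))
#else
  // CUDA builds: the *_sync variants additionally take a participation mask.
  #define WARP_SHFL(var, src_lane, width)      __shfl_sync(0xffffffff, (var), (src_lane), (width))
  #define WARP_SHFL_XOR(var, lane_mask, width) __shfl_xor_sync(0xffffffff, (var), (lane_mask), (width))
#endif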
apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp  (View file @ 9f899769)
File moved
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu  (View file @ 9f899769)

#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp  (View file @ 9f899769)

@@ -4,7 +4,7 @@
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
                      );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (View file @ 9f899769)

#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"

@@ -21,7 +24,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -48,8 +51,8 @@ std::vector<torch::Tensor> fwd_cuda(
   const int batch_stride  = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
   const float alpha       = 1.0;
   const float beta_zero   = 0.0;
   const float beta_one    = 1.0;
   const float scale       = 1.0 / sqrt(static_cast<float>(head_dim));
   // There is no reason to use more than one stream as every kernel is

@@ -82,10 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,

@@ -93,18 +95,23 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta_one),
-                q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
                         a_layout_t,

@@ -123,7 +130,11 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(bmm1_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(bmm1_results_ptr),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Padded Softmax
   bool softmax_success = false;
   if (is_training) {

@@ -168,12 +179,15 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
+                       static_cast<half*>(matmul2_results.data_ptr()),
+                       head_dim*attn_batches,
+                       head_dim,
                        attn_batches);
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -181,20 +195,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta_one),
-                static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   return {input_lin_results,

@@ -264,10 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -275,19 +289,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -295,17 +315,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1

@@ -326,8 +351,11 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(matmul2_grads.data_ptr()),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Matmul2 Dgrad2
   gemm_switch_fp32accum(state,
                         a_layout_n,

@@ -346,6 +374,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       v_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,
                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability

@@ -362,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
                        attn_batches*q_seq_len/sequences,
                        attn_batches*q_seq_len,
                        stream);
   // Matmul1 Dgrad1
   gemm_switch_fp32accum(state,
                         a_layout_n,

@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda(
                        q_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       q_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,
                        attn_batches);
   // Matmul1 Dgrad2
   gemm_switch_fp32accum(state,
                         a_layout_n,

@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       k_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,
                        attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-                static_cast<const void*>(input_lin_output_grads.data_ptr()),
-                //static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F, output_lin_dim,
+                static_cast<const void*>(input_lin_output_grads.data_ptr()),
+                rocblas_datatype_f16_r, output_lin_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
+                static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_grads,

@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp  (View file @ 9f899769)

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
                      );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward", &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  (View file @ 9f899769)

#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

@@ -21,7 +24,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -80,11 +83,10 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,

@@ -92,17 +94,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta_one),
-                q_lin_results_ptr, CUDA_R_16F, output_lin_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+                q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,

@@ -122,7 +129,11 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(softmax_results_ptr),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(softmax_results_ptr),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Padded Softmax
   bool softmax_success = false;
   if (pad_mask == nullptr) {

@@ -180,12 +191,15 @@ std::vector<torch::Tensor> fwd_cuda(
                        static_cast<half*>(matmul2_results.data_ptr()),
                        head_dim*attn_batches,
                        head_dim,
+                       static_cast<half*>(matmul2_results.data_ptr()),
+                       head_dim*attn_batches,
+                       head_dim,
                        attn_batches);
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -193,20 +207,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta_one),
-                static_cast<void*>(outputs.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   return {input_lin_results,

@@ -275,10 +291,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -286,19 +300,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(output_weights.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(output_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_lin_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_lin_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -306,17 +326,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(matmul2_results.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(output_grads.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(matmul2_results.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(output_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(output_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(output_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1

@@ -337,6 +362,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
+                       static_cast<half*>(matmul2_grads.data_ptr()),
+                       k_seq_len,
+                       k_seq_len*q_seq_len,
                        attn_batches);
   // Matmul2 Dgrad2

@@ -357,6 +385,9 @@ std::vector<torch::Tensor> bwd_cuda(
                        v_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       v_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,
                        attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability

@@ -385,7 +416,10 @@ std::vector<torch::Tensor> bwd_cuda(
                        static_cast<half*>(matmul2_grads.data_ptr()),
                        k_seq_len,
                        k_seq_len*q_seq_len,
                        beta,
                        q_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       q_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,

@@ -408,10 +442,13 @@ std::vector<torch::Tensor> bwd_cuda(
                        beta,
                        k_lin_grads_ptr,
                        lead_dim,
                        batch_stride,
+                       k_lin_grads_ptr,
+                       lead_dim,
+                       batch_stride,
                        attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,

@@ -419,22 +456,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(input_weights.data_ptr()), CUDA_R_16F, embed_dim,
+                static_cast<const void*>(input_weights.data_ptr()), rocblas_datatype_f16_r, embed_dim,
-                static_cast<const void*>(input_lin_output_grads.data_ptr()),
-                //static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F, output_lin_dim,
+                static_cast<const void*>(input_lin_output_grads.data_ptr()),
+                rocblas_datatype_f16_r, output_lin_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,

@@ -442,20 +482,24 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
-                static_cast<const void*>(inputs.data_ptr()), CUDA_R_16F, embed_dim,
-                static_cast<const void*>(q_lin_grads_ptr), CUDA_R_16F, output_lin_dim,
+                static_cast<const void*>(inputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<const void*>(q_lin_grads_ptr), rocblas_datatype_f16_r, output_lin_dim,
                 static_cast<const void*>(&beta),
-                static_cast<void*>(input_weight_grads.data_ptr()), CUDA_R_16F, embed_dim,
-                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                static_cast<void*>(input_weight_grads.data_ptr()), rocblas_datatype_f16_r, embed_dim,
+                rocblas_datatype_f32_r, algo, solution_index, flags));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_grads,

@@ -466,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp  (View file @ 9f899769)

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
                               bool use_time_mask,

@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
                      );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
+  m.def("forward", &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @
9f899769
#include <vector>
#include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
//
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
...
...
@@ -21,7 +24,7 @@ extern THCState *state;
namespace
multihead_attn
{
namespace
self
{
namespace
c
u
blas_gemmex
{
namespace
ro
cblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
...
...
@@ -78,9 +81,8 @@ std::vector<torch::Tensor> fwd_cuda(
char
a_layout_n
{
'n'
};
char
b_layout_n
{
'n'
};
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_TENSOR_OP_MATH
));
// Input Linear Fwd
THCublasCheck
(
c
u
blas
G
emm
E
x
(
handle
,
THCublasCheck
(
ro
cblas
_g
emm
_e
x
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
output_lin_dim
,
...
...
@@ -88,17 +90,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
static_cast
<
const
void
*>
(
inputs
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
q_lin_results_ptr
,
CUDA_R_16F
,
rocblas_datatype_f16_r
,
output_lin_dim
,
q_lin_results_ptr
,
rocblas_datatype_f16_r
,
output_lin_dim
,
CUDA_R_32F
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
rocblas_datatype_f32_r
,
algo
,
solution_index
,
flags
));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum
(
state
,
...
...
@@ -118,6 +125,9 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast
<
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
*
q_seq_len
,
static_cast
<
half
*>
(
softmax_results_ptr
),
k_seq_len
,
k_seq_len
*
q_seq_len
,
attn_batches
);
// Padded Softmax
...
...
@@ -179,10 +189,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast
<
half
*>
(
matmul2_results
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
static_cast
<
half
*>
(
matmul2_results
.
data_ptr
()),
head_dim
*
attn_batches
,
head_dim
,
attn_batches
);
// Output Linear
THCublasCheck
(
c
u
blas
G
emm
E
x
(
handle
,
THCublasCheck
(
ro
cblas
_g
emm
_e
x
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
embed_dim
,
...
...
@@ -190,19 +203,22 @@ std::vector<torch::Tensor> fwd_cuda(
embed_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
output_weights
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
static_cast
<
const
void
*>
(
matmul2_results
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
outputs
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
CUDA_R_32F
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_DEFAULT_MATH
));
static_cast
<
void
*>
(
outputs
.
data_ptr
()),
rocblas_datatype_f16_r
,
embed_dim
,
rocblas_datatype_f32_r
,
algo
,
solution_index
,
flags
));
return
{
input_lin_results
,
...
...
@@ -270,11 +286,9 @@ std::vector<torch::Tensor> bwd_cuda(
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      embed_dim,
...
@@ -282,20 +296,25 @@ std::vector<torch::Tensor> bwd_cuda(
      embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(output_weights.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(output_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_lin_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      static_cast<void*>(output_lin_grads.data_ptr()),
      rocblas_datatype_f16_r,
      embed_dim,
      rocblas_datatype_f32_r,
      algo,
      solution_index,
      flags));
  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      embed_dim,
...
@@ -303,17 +322,22 @@ std::vector<torch::Tensor> bwd_cuda(
      batches,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(matmul2_results.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(output_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_weight_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<void*>(output_weight_grads.data_ptr()),
      rocblas_datatype_f16_r,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      rocblas_datatype_f32_r,
      algo,
      solution_index,
      flags));
  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state,
...
@@ -333,6 +357,9 @@ std::vector<torch::Tensor> bwd_cuda(
      static_cast<half*>(matmul2_grads.data_ptr()),
      k_seq_len,
      k_seq_len*q_seq_len,
      static_cast<half*>(matmul2_grads.data_ptr()),
      k_seq_len,
      k_seq_len*q_seq_len,
      attn_batches);
  // Matmul2 Dgrad2
...
@@ -353,6 +380,9 @@ std::vector<torch::Tensor> bwd_cuda(
      v_lin_grads_ptr,
      lead_dim,
      batch_stride,
      v_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
  // Apply Dropout Mask and Scale by Dropout Probability
...
@@ -392,6 +422,9 @@ std::vector<torch::Tensor> bwd_cuda(
      q_lin_grads_ptr,
      lead_dim,
      batch_stride,
      q_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
  // Matmul1 Dgrad2
...
@@ -411,11 +444,14 @@ std::vector<torch::Tensor> bwd_cuda(
      beta,
      k_lin_grads_ptr,
      lead_dim,
      batch_stride,
      k_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
  // Input Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      embed_dim,
...
@@ -423,20 +459,25 @@ std::vector<torch::Tensor> bwd_cuda(
      output_lin_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(input_weights.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(q_lin_grads_ptr),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      output_lin_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(input_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      static_cast<void*>(input_grads.data_ptr()),
      rocblas_datatype_f16_r,
      embed_dim,
      rocblas_datatype_f32_r,
      algo,
      solution_index,
      flags));
  // Input Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      embed_dim,
...
@@ -444,18 +485,22 @@ std::vector<torch::Tensor> bwd_cuda(
      batches,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(inputs.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<const void*>(q_lin_grads_ptr),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      output_lin_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(input_weight_grads.data_ptr()),
      CUDA_R_16F,
      rocblas_datatype_f16_r,
      embed_dim,
      static_cast<void*>(input_weight_grads.data_ptr()),
      rocblas_datatype_f16_r,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
      rocblas_datatype_f32_r,
      algo,
      solution_index,
      flags));
  return {input_grads,
...
@@ -464,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
  };
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
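For reference, the recurring change in this file is the swap from cublasGemmEx to rocblas_gemm_ex. A minimal sketch of how the arguments line up, assuming FP16 storage with FP32 accumulation as in the calls above; the wrapper name gemm_fp16_fp32acc and the zero algo/solution_index/flags values are illustrative, not part of the commit:

// Sketch only: cublasGemmEx -> rocblas_gemm_ex argument mapping used above
// (FP16 A/B/C, FP32 accumulation, with C aliased as the D output buffer).
#include <rocblas.h>   // header path may differ across ROCm versions

rocblas_status gemm_fp16_fp32acc(rocblas_handle handle,
                                 rocblas_operation op_a, rocblas_operation op_b,
                                 int m, int n, int k,
                                 const float *alpha,
                                 const void *a, int lda,   // half-precision data
                                 const void *b, int ldb,   // half-precision data
                                 const float *beta,
                                 void *c, int ldc) {       // half-precision data
  return rocblas_gemm_ex(handle, op_a, op_b, m, n, k,
                         alpha,
                         a, rocblas_datatype_f16_r, lda,
                         b, rocblas_datatype_f16_r, ldb,
                         beta,
                         c, rocblas_datatype_f16_r, ldc,   // C
                         c, rocblas_datatype_f16_r, ldc,   // D (in place, as above)
                         rocblas_datatype_f32_r,           // compute type
                         rocblas_gemm_algo_standard,
                         0 /* solution_index */, 0 /* flags */);
}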
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp
View file @ 9f899769
...
@@ -3,7 +3,7 @@
namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
      bool use_time_mask,
...
@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd(
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
  m.def("forward", &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
  m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
}
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
View file @ 9f899769
#include <vector>
#include <iostream>
// The lines below enable HIP float-to-half conversions, which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
...
@@ -21,7 +25,7 @@ extern THCState *state;
namespace multihead_attn {
namespace self_norm_add {
namespace cublas_gemmex {
namespace rocblas_gemmex {
std::vector<torch::Tensor> fwd_cuda(
      bool use_time_mask,
...
@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda(
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Layer Norm
  HostApplyLayerNorm<at::Half, float>(
      static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...
@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
      static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
  // Input Linear Fwd
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      output_lin_dim,
...
@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda(
      embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(input_weights.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      //static_cast<const void*>(inputs.data_ptr()),
      static_cast<const void*>(lyr_nrm_results.data_ptr()),
      CUDA_R_16F,
      b_type,
      embed_dim,
      static_cast<const void*>(&beta),
      q_lin_results_ptr,
      CUDA_R_16F,
      c_type,
      output_lin_dim,
      q_lin_results_ptr,
      d_type,
      output_lin_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      compute_type,
      algo,
      solution_index,
      flags));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state,
...
@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
      static_cast<half*>(softmax_results_ptr),
      k_seq_len,
      k_seq_len*q_seq_len,
      static_cast<half*>(softmax_results_ptr),
      k_seq_len,
      k_seq_len*q_seq_len,
      attn_batches);
  // Padded Softmax
...
@@ -202,11 +214,14 @@ std::vector<torch::Tensor> fwd_cuda(
      beta,
      static_cast<half*>(matmul2_results.data_ptr()),
      head_dim*attn_batches,
      head_dim,
      head_dim,
      static_cast<half*>(matmul2_results.data_ptr()),
      head_dim*attn_batches,
      head_dim,
      attn_batches);
  // Output Linear
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_T,
      CUBLAS_OP_N,
      embed_dim,
...
@@ -214,18 +229,24 @@ std::vector<torch::Tensor> fwd_cuda(
      embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(output_weights.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      static_cast<const void*>(matmul2_results.data_ptr()),
      CUDA_R_16F,
      b_type,
      embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_lin_results.data_ptr()),
      CUDA_R_16F,
      c_type,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      static_cast<void*>(output_lin_results.data_ptr()),
      d_type,
      embed_dim,
      compute_type,
      algo,
      solution_index,
      flags));
  // End-of-block Dropout-Add
  if (is_training) {
    apex_dropout_add_cuda<at::Half, float, uint32_t>(
...
@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda(
      total_tokens);
  }
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {lyr_nrm_results,
          lyr_nrm_mean,
...
@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda(
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Dropout Add Backward
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
      static_cast<at::Half const*>(output_grads.data_ptr()),
...
@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
      (1.0 / (1.0 - dropout_prob)));
  // Output Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      embed_dim,
...
@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda(
      embed_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(output_weights.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      static_cast<const void*>(dropout_add_grads.data_ptr()),
      CUDA_R_16F,
      b_type,
      embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_lin_grads.data_ptr()),
      CUDA_R_16F,
      c_type,
      embed_dim,
      static_cast<void*>(output_lin_grads.data_ptr()),
      d_type,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      compute_type,
      algo,
      solution_index,
      flags));
  // Output Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      embed_dim,
...
@@ -367,18 +389,23 @@ std::vector<torch::Tensor> bwd_cuda(
      batches,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(matmul2_results.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      static_cast<const void*>(dropout_add_grads.data_ptr()),
      CUDA_R_16F,
      b_type,
      embed_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(output_weight_grads.data_ptr()),
      CUDA_R_16F,
      c_type,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      static_cast<void*>(output_weight_grads.data_ptr()),
      d_type,
      embed_dim,
      compute_type,
      algo,
      solution_index,
      flags));
  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state,
      a_layout_t,
...
@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
      static_cast<half*>(matmul2_grads.data_ptr()),
      k_seq_len,
      k_seq_len*q_seq_len,
      static_cast<half*>(matmul2_grads.data_ptr()),
      k_seq_len,
      k_seq_len*q_seq_len,
      attn_batches);
  // Matmul2 Dgrad2
...
@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
      v_lin_grads_ptr,
      lead_dim,
      batch_stride,
      v_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
  // Apply Dropout Mask and Scale by Dropout Probability
...
@@ -455,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
      beta,
      q_lin_grads_ptr,
      lead_dim,
      batch_stride,
      q_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
...
@@ -475,11 +511,14 @@ std::vector<torch::Tensor> bwd_cuda(
      beta,
      k_lin_grads_ptr,
      lead_dim,
      batch_stride,
      batch_stride,
      k_lin_grads_ptr,
      lead_dim,
      batch_stride,
      attn_batches);
  // Input Linear Dgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_N,
      embed_dim,
...
@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda(
      output_lin_dim,
      static_cast<const void*>(&alpha),
      static_cast<const void*>(input_weights.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      static_cast<const void*>(q_lin_grads_ptr),
      CUDA_R_16F,
      b_type,
      output_lin_dim,
      static_cast<const void*>(&beta),
      //static_cast<void*>(input_grads.data_ptr()),
      static_cast<void*>(input_lin_grads.data_ptr()),
      CUDA_R_16F,
      c_type,
      embed_dim,
      CUDA_R_32F,
      //CUBLAS_GEMM_ALGO10_TENSOR_OP));
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      static_cast<void*>(input_lin_grads.data_ptr()),
      d_type,
      embed_dim,
      compute_type,
      algo,
      solution_index,
      flags));
  // Input Linear Wgrad
  THCublasCheck(cublasGemmEx(handle,
  THCublasCheck(rocblas_gemm_ex(handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      embed_dim,
...
@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda(
      static_cast<const void*>(&alpha),
      //static_cast<const void*>(inputs.data_ptr()),
      static_cast<const void*>(lyr_nrm_results.data_ptr()),
      CUDA_R_16F,
      a_type,
      embed_dim,
      static_cast<const void*>(q_lin_grads_ptr),
      CUDA_R_16F,
      b_type,
      output_lin_dim,
      static_cast<const void*>(&beta),
      static_cast<void*>(input_weight_grads.data_ptr()),
      CUDA_R_16F,
      c_type,
      embed_dim,
      static_cast<void*>(input_weight_grads.data_ptr()),
      d_type,
      embed_dim,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP));
      compute_type,
      algo,
      solution_index,
      flags));
  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient<half, float>(
...
@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda(
      static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
      );
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads,
...
@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda(
  };
}
} // end namespace cublas_gemmex
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
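The rocblas_gemm_ex calls in this file refer to a_type, b_type, c_type, d_type, compute_type, algo, solution_index and flags, which are presumably declared earlier in the file, outside the hunks shown. A plausible set of definitions consistent with the FP16-storage / FP32-accumulation setup above, offered as a sketch rather than the file's actual declarations:

// Assumed helper constants for the rocblas_gemm_ex calls above; the names
// match the diff, the values are a guess consistent with the CUDA_R_16F /
// CUDA_R_32F usage they replace.
rocblas_datatype  a_type         = rocblas_datatype_f16_r;  // half-precision A
rocblas_datatype  b_type         = rocblas_datatype_f16_r;  // half-precision B
rocblas_datatype  c_type         = rocblas_datatype_f16_r;  // half-precision C
rocblas_datatype  d_type         = rocblas_datatype_f16_r;  // half-precision D
rocblas_datatype  compute_type   = rocblas_datatype_f32_r;  // accumulate in FP32
rocblas_gemm_algo algo           = rocblas_gemm_algo_standard;
int32_t           solution_index = 0;                       // let rocBLAS choose
uint32_t          flags          = 0;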
apex/contrib/csrc/multihead_attn/softmax.h
View file @ 9f899769
...
@@ -11,7 +11,14 @@
#include <cuda_fp16.h>
#include <cmath>
#ifdef __HIP_PLATFORM_HCC__
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
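The APEX_WARP_SHFL_XOR macro above gives the kernels a single spelling for the warp shuffle: on HIP it drops the sync mask and calls __shfl_xor, on CUDA it forwards to __shfl_xor_sync. A minimal warp-wide sum reduction in the same style as the loops that follow (FULL_MASK is assumed to be 0xffffffff, and the template parameter stands in for the kernels' WARP_SIZE; this helper is a sketch, not part of the commit):

// Sketch of the butterfly-reduction pattern used below; every lane ends up
// holding the sum of v across the warp.
#ifndef FULL_MASK
#define FULL_MASK 0xffffffff
#endif
template <int WARP_SIZE>
__device__ __forceinline__ float warp_sum(float v) {
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        // mask is ignored by __shfl_xor on HIP, used by __shfl_xor_sync on CUDA
        v += APEX_WARP_SHFL_XOR(FULL_MASK, v, offset, WARP_SIZE);
    }
    return v;
}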
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
...
@@ -127,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -152,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -351,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -375,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
    auto seeds = at::cuda::philox::unpack(philox_args);
...
@@ -505,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -529,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
    curandStatePhilox4_32_10_t state;
...
@@ -765,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -790,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -1020,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -1045,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -1243,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -1268,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -1385,7 +1392,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
    return false;
}
int log2_ceil_native(int value) {
static int log2_ceil_native(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
...
@@ -1394,7 +1401,7 @@ int log2_ceil_native(int value) {
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
...
@@ -1835,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
    float val[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
        val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
    }
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
...
@@ -1860,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -2305,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...
@@ -2516,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
            sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
...