Commit 1436a66a authored Dec 02, 2021 by hubertlu-tw

Merge remote-tracking branch 'origin/master' into IFU-master-2021-10-15

parents aee9f00d 08e88b1b
Showing 20 changed files with 703 additions and 389 deletions (+703 −389)
.gitignore  +3 −0
README.md  +4 −4
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu  +1 −1
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp  +0 −0
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu  +1 −1
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp  +4 −4
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  +125 −67
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp  +3 −3
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  +129 −72
apex/contrib/csrc/multihead_attn/layer_norm.h  +25 −14
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp  +0 −0
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu  +5 −1
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp  +4 −4
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +105 −60
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp  +4 −4
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  +96 −52
apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp  +4 −4
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  +91 −46
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp  +3 −3
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  +96 −49
.gitignore

@@ -145,3 +145,6 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
+*.hip
+*_hip.*
+*hip*
README.md

@@ -129,18 +129,18 @@ Note: Pytorch version recommended is >=1.5 for extension build.
 ### To install using python only build use the following command in apex folder:
 ```
-python3.6 setup.py install
+python setup.py install
 ```
 ### To install using extensions enabled use the following command in apex folder:
 ```
-python3.6 setup.py install --cpp_ext --cuda_ext
+python setup.py install --cpp_ext --cuda_ext
 ```
 ### To install Apex on ROCm using ninja and without cloning the source
 ```
-pip3.6 install ninja
-pip3.6 install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
+pip install ninja
+pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
 ```
 ### Linux
apex/contrib/csrc/layer_norm/ln_fwd_cuda_kernel.cu

@@ -183,4 +183,4 @@ void ln_fwd_cuda(
     assert( false && "Not implemented" );
   }
-}
+}
\ No newline at end of file
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp (file moved)
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu

@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
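The commit simply comments out `<cuda_profiler_api.h>`, which has no direct HIP counterpart. A minimal alternative sketch, assuming the only goal is to keep the header on CUDA-only builds (the guard macro is the one the commit itself uses elsewhere; this snippet is not part of the commit):

```
// Hedged sketch: keep the profiler header for CUDA builds only instead of
// deleting it outright.
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
```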
apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -146,11 +146,11 @@ std::vector<torch::Tensor> bwd(
 );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::encdec::cublas_gemmex::fwd, "Encdec Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::encdec::cublas_gemmex::bwd, "Encdec Multihead Attention Backward.");
+  m.def("forward", &multihead_attn::encdec::rocblas_gemmex::fwd, "Encdec Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
 }
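Only the C++ namespace changes in these `*_cpp.cpp` files; the names registered with `m.def` stay `"forward"` and `"backward"`, so the Python-visible surface of the extension is unchanged. A hedged sketch of the host-side wrapper pattern these files follow (the argument list is abbreviated and assumed, since the full signature is collapsed in the diff):

```
// Sketch, not the commit's exact code: fwd() checks inputs on the host and
// forwards to the device implementation declared above as fwd_cuda().
namespace multihead_attn { namespace encdec { namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd(bool use_time_mask, bool is_training, int heads,
                               torch::Tensor const& inputs_q) {
  AT_ASSERTM(inputs_q.dim() == 3, "expected 3D tensor");  // illustrative check
  return fwd_cuda(use_time_mask, is_training, heads, inputs_q);
}

}}} // end namespaces rocblas_gemmex, encdec, multihead_attn
```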
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

@@ -22,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -86,9 +89,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    output_lin_q_dim,
@@ -96,20 +99,25 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_q.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(inputs_q.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    q_lin_results_ptr,
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_q_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   q_lin_results_ptr,
+                   rocblas_datatype_f16_r,
+                   output_lin_q_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    output_lin_kv_dim,
@@ -117,17 +125,22 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_kv.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(inputs_kv.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    k_lin_results_ptr,
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_kv_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   k_lin_results_ptr,
+                   rocblas_datatype_f16_r,
+                   output_lin_kv_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
@@ -146,6 +159,9 @@ std::vector<torch::Tensor> fwd_cuda(
                    beta,
                    static_cast<half*>(softmax_results_ptr),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(softmax_results_ptr),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
@@ -208,10 +224,13 @@ std::vector<torch::Tensor> fwd_cuda(
                    static_cast<half*>(matmul2_results.data_ptr()),
                    head_dim*attn_batches,
                    head_dim,
+                   static_cast<half*>(matmul2_results.data_ptr()),
+                   head_dim*attn_batches,
+                   head_dim,
                    attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -219,20 +238,22 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(outputs.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                   static_cast<void*>(outputs.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   return {
     input_lin_q_results,
@@ -312,10 +333,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -323,20 +342,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(output_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_lin_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_lin_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -344,17 +368,22 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_q,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(output_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_weight_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_weight_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,
@@ -374,6 +403,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<half*>(matmul2_grads.data_ptr()),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(matmul2_grads.data_ptr()),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
   // Matmul2 Dgrad2
@@ -394,6 +426,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    v_lin_grads_ptr,
                    lead_dim_kv,
                    batch_stride_kv,
+                   v_lin_grads_ptr,
+                   lead_dim_kv,
+                   batch_stride_kv,
                    attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
@@ -433,6 +468,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    q_lin_grads_ptr,
                    lead_dim_q,
                    batch_stride_q,
+                   q_lin_grads_ptr,
+                   lead_dim_q,
+                   batch_stride_q,
                    attn_batches);
   // Matmul1 Dgrad2
@@ -453,10 +491,13 @@ std::vector<torch::Tensor> bwd_cuda(
                    k_lin_grads_ptr,
                    lead_dim_kv,
                    batch_stride_kv,
+                   k_lin_grads_ptr,
+                   lead_dim_kv,
+                   batch_stride_kv,
                    attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -464,21 +505,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    output_lin_q_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_q.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(q_lin_grads_ptr),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_q_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_q_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_q_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -486,20 +531,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_q,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(inputs_q.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(q_lin_grads_ptr),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_q_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_weight_q_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_weight_q_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -507,21 +557,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    output_lin_kv_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_kv.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(k_lin_grads_ptr),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_kv_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_kv_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_kv_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -529,18 +583,22 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_kv,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(inputs_kv.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(k_lin_grads_ptr),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_kv_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_weight_kv_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                   static_cast<void*>(input_weight_kv_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   return {
     input_q_grads,
@@ -551,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn
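The recurring transformation in this file maps each `cublasGemmEx` call onto `rocblas_gemm_ex`. The rocBLAS entry point takes a separate output matrix D in addition to C, plus an `(algo, solution_index, flags)` triple, which is why every call gains a repeated pointer/type/leading-dimension block and a longer tail. A condensed sketch of the rewrite (not a compilable excerpt; `m`, `n`, `k`, and the leading dimensions stand for the per-call values, and the `algo`/`solution_index`/`flags` variables are assumed to be declared earlier in the file):

```
// Before (cuBLAS): C is both input and output; one algorithm enum.
cublasGemmEx(handle, transa, transb, m, n, k,
             &alpha, A, CUDA_R_16F, lda,
                     B, CUDA_R_16F, ldb,
             &beta,  C, CUDA_R_16F, ldc,
             CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);

// After (rocBLAS): an explicit output matrix D is passed, here aliased to C
// for an in-place result, and the algorithm is the (algo, solution_index,
// flags) triple, presumably rocblas_gemm_algo_standard, 0, 0.
rocblas_gemm_ex(handle, transa, transb, m, n, k,
                &alpha, A, rocblas_datatype_f16_r, lda,
                        B, rocblas_datatype_f16_r, ldb,
                &beta,  C, rocblas_datatype_f16_r, ldc,
                        C, rocblas_datatype_f16_r, ldc,  // D = C (in-place)
                rocblas_datatype_f32_r,                  // accumulate in fp32
                algo, solution_index, flags);
```

The dropped `cublasSetMathMode(..., CUBLAS_TENSOR_OP_MATH)` / `CUBLAS_DEFAULT_MATH` bracket has no rocBLAS equivalent; the fp32 `compute_type` argument carries the mixed-precision intent instead.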
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -192,7 +192,7 @@ std::vector<torch::Tensor> bwd(
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::encdec_norm_add::cublas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
-  m.def("backward", &multihead_attn::encdec_norm_add::cublas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
+  m.def("forward", &multihead_attn::encdec_norm_add::rocblas_gemmex::fwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Forward.");
+  m.def("backward", &multihead_attn::encdec_norm_add::rocblas_gemmex::bwd, "Encdec Multihead Attention Plus Layer Norm and Residual Add Backward.");
 }
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

@@ -21,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace encdec_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -95,7 +99,6 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm
   HostApplyLayerNorm<at::Half,float>(
     static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
@@ -109,7 +112,7 @@ std::vector<torch::Tensor> fwd_cuda(
     static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Q Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    output_lin_q_dim,
@@ -117,21 +120,26 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_q.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    //static_cast<const void*>(inputs_q.data_ptr()),
                    static_cast<const void*>(lyr_nrm_results.data_ptr()),
-                   CUDA_R_16F,
+                   b_type,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    q_lin_results_ptr,
-                   CUDA_R_16F,
+                   c_type,
                    output_lin_q_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   q_lin_results_ptr,
+                   d_type,
+                   output_lin_q_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    output_lin_kv_dim,
@@ -139,18 +147,22 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_kv.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(inputs_kv.data_ptr()),
-                   CUDA_R_16F,
+                   b_type,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    k_lin_results_ptr,
-                   CUDA_R_16F,
+                   c_type,
                    output_lin_kv_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   k_lin_results_ptr,
+                   d_type,
+                   output_lin_kv_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
                    a_layout_t,
@@ -168,7 +180,10 @@ std::vector<torch::Tensor> fwd_cuda(
                    beta,
                    static_cast<half*>(softmax_results_ptr),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(softmax_results_ptr),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
   // Padded Softmax
@@ -230,11 +245,14 @@ std::vector<torch::Tensor> fwd_cuda(
                    beta,
                    static_cast<half*>(matmul2_results.data_ptr()),
                    head_dim*attn_batches,
                    head_dim,
+                   static_cast<half*>(matmul2_results.data_ptr()),
+                   head_dim*attn_batches,
+                   head_dim,
                    attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -242,19 +260,23 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   b_type,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_lin_results.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_lin_results.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // End-of-block Dropout-Add
   if (is_training) {
     apex_dropout_add_cuda<at::Half,float,uint32_t>(
@@ -272,8 +294,6 @@ std::vector<torch::Tensor> fwd_cuda(
       total_tokens_q);
   }
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     lyr_nrm_results,
     lyr_nrm_mean,
@@ -366,9 +386,7 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half,float,uint32_t>(
     static_cast<at::Half const*>(output_grads.data_ptr()),
@@ -378,7 +396,7 @@ std::vector<torch::Tensor> bwd_cuda(
     (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -386,20 +404,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(dropout_add_grads.data_ptr()),
-                   CUDA_R_16F,
+                   b_type,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_lin_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_lin_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -407,17 +430,22 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_q,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(dropout_add_grads.data_ptr()),
-                   CUDA_R_16F,
+                   b_type,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_weight_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_weight_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,
@@ -437,6 +465,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<half*>(matmul2_grads.data_ptr()),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(matmul2_grads.data_ptr()),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
   // Matmul2 Dgrad2
@@ -457,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    v_lin_grads_ptr,
                    lead_dim_kv,
                    batch_stride_kv,
+                   v_lin_grads_ptr,
+                   lead_dim_kv,
+                   batch_stride_kv,
                    attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
@@ -496,6 +530,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    q_lin_grads_ptr,
                    lead_dim_q,
                    batch_stride_q,
+                   q_lin_grads_ptr,
+                   lead_dim_q,
+                   batch_stride_q,
                    attn_batches);
   // Matmul1 Dgrad2
@@ -515,11 +552,14 @@ std::vector<torch::Tensor> bwd_cuda(
                    beta,
                    k_lin_grads_ptr,
                    lead_dim_kv,
                    batch_stride_kv,
+                   k_lin_grads_ptr,
+                   lead_dim_kv,
+                   batch_stride_kv,
                    attn_batches);
   // Input Linear Q Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -527,22 +567,26 @@ std::vector<torch::Tensor> bwd_cuda(
                    output_lin_q_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_q.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(q_lin_grads_ptr),
-                   CUDA_R_16F,
+                   b_type,
                    output_lin_q_dim,
                    static_cast<const void*>(&beta),
                    //static_cast<void*>(input_q_grads.data_ptr()),
                    static_cast<void*>(input_lin_q_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_lin_q_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear Q Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -550,20 +594,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_q,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(inputs_q.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(q_lin_grads_ptr),
-                   CUDA_R_16F,
+                   b_type,
                    output_lin_q_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_weight_q_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_weight_q_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -571,21 +620,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    output_lin_kv_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights_kv.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(k_lin_grads_ptr),
-                   CUDA_R_16F,
+                   b_type,
                    output_lin_kv_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_kv_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_kv_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Input Linear KV Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -593,17 +646,22 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches_kv,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(inputs_kv.data_ptr()),
-                   CUDA_R_16F,
+                   a_type,
                    embed_dim,
                    static_cast<const void*>(k_lin_grads_ptr),
-                   CUDA_R_16F,
+                   b_type,
                    output_lin_kv_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(input_weight_kv_grads.data_ptr()),
-                   CUDA_R_16F,
+                   c_type,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(input_weight_kv_grads.data_ptr()),
+                   d_type,
+                   embed_dim,
+                   compute_type,
+                   algo,
+                   solution_index,
+                   flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half,float>(
@@ -622,7 +680,6 @@ std::vector<torch::Tensor> bwd_cuda(
     static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
   );
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_q_grads,
@@ -635,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
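Unlike the plain encdec file above, the norm-add hunks pass symbolic `a_type`/`b_type`/`c_type`/`d_type`/`compute_type` and `algo`/`solution_index`/`flags` names. Their declarations sit outside the visible hunks; a hedged sketch of what they plausibly resolve to, inferred from the literal values used in the sibling files:

```
// Hedged sketch of the declarations the hunks above rely on; the names match
// the diff, the initializers are inferred from the sibling files.
rocblas_datatype a_type = rocblas_datatype_f16_r;        // A operand: fp16
rocblas_datatype b_type = rocblas_datatype_f16_r;        // B operand: fp16
rocblas_datatype c_type = rocblas_datatype_f16_r;        // C operand: fp16
rocblas_datatype d_type = rocblas_datatype_f16_r;        // D output: fp16
rocblas_datatype compute_type = rocblas_datatype_f32_r;  // accumulate in fp32
rocblas_gemm_algo algo = rocblas_gemm_algo_standard;     // let rocBLAS pick
int32_t solution_index = 0;                              // default solution
uint32_t flags = 0;                                      // no special flags
```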
apex/contrib/csrc/multihead_attn/layer_norm.h

@@ -4,6 +4,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 template<typename U> __device__
 void cuWelfordOnlineSum(
   const U curr,
@@ -84,9 +85,9 @@ void cuWelfordMuSigma2(
   // intra-warp reductions
   for (int l = 0; l <= 4; ++l) {
     int srcLaneB = (threadIdx.x+(1<<l))&31;
-    U muB = WARP_SHFL(mu, srcLaneB);
-    U countB = WARP_SHFL(count, srcLaneB);
-    U sigma2B = WARP_SHFL(sigma2, srcLaneB);
+    U muB = WARP_SHFL(mu, srcLaneB, 32);
+    U countB = WARP_SHFL(count, srcLaneB, 32);
+    U sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
     cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
   }
   // threadIdx.x == 0 has correct values for each warp
@@ -122,8 +123,8 @@ void cuWelfordMuSigma2(
       sigma2 = ubuf[1]/U(n2);
       // don't care about final value of count, we know count == n2
     } else {
-      mu = WARP_SHFL(mu, 0);
-      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
+      mu = WARP_SHFL(mu, 0, 32);
+      sigma2 = WARP_SHFL(sigma2/U(n2), 0, 32);
     }
   }
@@ -180,9 +181,9 @@ void cuWelfordMuSigma2(
   // intra-warp reductions
   for (int l = 0; l <= 4; ++l) {
     int srcLaneB = (threadIdx.x+(1<<l))&31;
-    float muB = WARP_SHFL(mu, srcLaneB);
-    float countB = WARP_SHFL(count, srcLaneB);
-    float sigma2B = WARP_SHFL(sigma2, srcLaneB);
+    float muB = WARP_SHFL(mu, srcLaneB, 32);
+    float countB = WARP_SHFL(count, srcLaneB, 32);
+    float sigma2B = WARP_SHFL(sigma2, srcLaneB, 32);
     cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
   }
   // threadIdx.x == 0 has correct values for each warp
@@ -218,8 +219,8 @@ void cuWelfordMuSigma2(
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
-      mu = WARP_SHFL(mu, 0);
-      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
+      mu = WARP_SHFL(mu, 0, 32);
+      sigma2 = WARP_SHFL(sigma2/float(n2), 0, 32);
    }
  }
@@ -227,9 +228,19 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
   return U(1) / sqrt(v);
 }
+//template<> float rsqrt(float v) {
+//  return rsqrtf(v);
+//}
+#if defined __HIP_PLATFORM_HCC__
+__device__ float rsqrt(float v) {
+  return rsqrtf(v);
+}
+#else
 template<> float rsqrt(float v) {
   return rsqrtf(v);
 }
+#endif
 template<> double rsqrt(double v) {
   return rsqrt(v);
 }
@@ -290,7 +301,7 @@ void cuApplyLayerNorm(
   // 1) blockDim.x == warpSize
   // 2) Tensors are contiguous
   //
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu,sigma2;
@@ -529,7 +540,7 @@ void cuComputeGradInput(
     const T* gamma,
     T* grad_input)
 {
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];
@@ -574,8 +585,8 @@ void cuComputeGradInput(
     }
     // intra-warp reductions
     for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
-      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
-      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
+      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask, 32);
+      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask, 32);
     }
     // inter-warp reductions
     if (blockDim.y > 1) {
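The layer_norm.h changes are the most HIP-specific in the commit: the `WARP_SHFL`/`WARP_SHFL_XOR` calls gain an explicit width of 32 because an AMD wavefront is 64 lanes wide, while this code's lane indexing (`&31`, `l <= 4`) assumes 32-lane groups; the float `rsqrt` specialization becomes a plain `__device__` overload under `__HIP_PLATFORM_HCC__`; and the grid-stride loop counters switch from `auto` (deduced unsigned from `blockIdx.y`) to `int`. A minimal sketch of a width-aware shuffle wrapper that behaves the same on both platforms (apex's actual `WARP_SHFL` macro is defined elsewhere and may differ; this is illustrative only):

```
// Hedged sketch of a width-aware shuffle wrapper, not apex's real definition.
template <typename T>
__device__ __forceinline__ T warp_shfl(T val, int src_lane, int width = 32) {
#if defined(__HIP_PLATFORM_HCC__)
  return __shfl(val, src_lane, width);                    // HIP: no lane mask
#else
  return __shfl_sync(0xffffffff, val, src_lane, width);   // CUDA 9+: full mask
#endif
}
```

Passing `width = 32` on HIP restricts the shuffle to 32-lane subgroups of the 64-lane wavefront, preserving the reduction shape the CUDA code was written for.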
apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp (file moved)
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu

 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
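The `#undef` pair re-enables the half-precision operators and float-to-half conversions that hip_fp16.h hides by default, which the kernels in these files rely on. A minimal illustration of the kind of expression that only compiles once they are re-enabled (hypothetical kernel, not from the commit):

```
// Hypothetical kernel: 'in[i] * (half)s' needs both the half*half operator
// and the float->half conversion that the #undefs re-enable under HIP.
__global__ void scale_half(half* out, const half* in, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * (half)s;
}
```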
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp

@@ -4,7 +4,7 @@
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
 );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

 #include <vector>
+#include <math.h>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#include <ATen/ATen.h>
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
-#include "THC/THC.h"
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
-#include <math.h>
 #include "strided_batched_gemm.h"
 #include "softmax.h"

@@ -21,7 +24,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
   bool use_time_mask,
@@ -48,8 +51,8 @@ std::vector<torch::Tensor> fwd_cuda(
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
   const float alpha = 1.0;
   const float beta_zero = 0.0;
   const float beta_one = 1.0;
   const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
   // There is no reason to use more than one stream as every kernel is
@@ -82,10 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    output_lin_dim,
@@ -93,18 +95,23 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(input_weights.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(inputs.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta_one),
                    q_lin_results_ptr,
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    output_lin_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   q_lin_results_ptr,
+                   rocblas_datatype_f16_r,
+                   output_lin_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
                    a_layout_t,
@@ -123,7 +130,11 @@ std::vector<torch::Tensor> fwd_cuda(
                    static_cast<half*>(bmm1_results_ptr),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(bmm1_results_ptr),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
   // Padded Softmax
   bool softmax_success = false;
   if (is_training) {
@@ -168,12 +179,15 @@ std::vector<torch::Tensor> fwd_cuda(
                    static_cast<half*>(matmul2_results.data_ptr()),
                    head_dim*attn_batches,
                    head_dim,
+                   static_cast<half*>(matmul2_results.data_ptr()),
+                   head_dim*attn_batches,
+                   head_dim,
                    attn_batches);
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_T,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -181,20 +195,22 @@ std::vector<torch::Tensor> fwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta_one),
                    static_cast<void*>(outputs.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+                   static_cast<void*>(outputs.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   return {
     input_lin_results,
@@ -264,10 +280,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    embed_dim,
@@ -275,19 +289,25 @@ std::vector<torch::Tensor> bwd_cuda(
                    embed_dim,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(output_weights.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(output_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_lin_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_lin_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_T,
                    embed_dim,
@@ -295,17 +315,22 @@ std::vector<torch::Tensor> bwd_cuda(
                    batches,
                    static_cast<const void*>(&alpha),
                    static_cast<const void*>(matmul2_results.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(output_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
                    static_cast<const void*>(&beta),
                    static_cast<void*>(output_weight_grads.data_ptr()),
-                   CUDA_R_16F,
+                   rocblas_datatype_f16_r,
                    embed_dim,
-                   CUDA_R_32F,
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                   static_cast<void*>(output_weight_grads.data_ptr()),
+                   rocblas_datatype_f16_r,
+                   embed_dim,
+                   rocblas_datatype_f32_r,
+                   algo,
+                   solution_index,
+                   flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
@@ -326,8 +351,11 @@ std::vector<torch::Tensor> bwd_cuda(
                    static_cast<half*>(matmul2_grads.data_ptr()),
                    k_seq_len,
                    k_seq_len*q_seq_len,
+                   static_cast<half*>(matmul2_grads.data_ptr()),
+                   k_seq_len,
+                   k_seq_len*q_seq_len,
                    attn_batches);
   // Matmul2 Dgrad2
   gemm_switch_fp32accum(state,
                    a_layout_n,
@@ -346,6 +374,9 @@ std::vector<torch::Tensor> bwd_cuda(
                    v_lin_grads_ptr,
                    lead_dim,
                    batch_stride,
+                   v_lin_grads_ptr,
+                   lead_dim,
+                   batch_stride,
                    attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
@@ -362,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
                    attn_batches
*
q_seq_len
/
sequences
,
attn_batches
*
q_seq_len
/
sequences
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
,
stream
);
stream
);
// Matmul1 Dgrad1
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
a_layout_n
,
a_layout_n
,
...
@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -381,8 +412,11 @@ std::vector<torch::Tensor> bwd_cuda(
q_lin_grads_ptr
,
q_lin_grads_ptr
,
lead_dim
,
lead_dim
,
batch_stride
,
batch_stride
,
q_lin_grads_ptr
,
lead_dim
,
batch_stride
,
attn_batches
);
attn_batches
);
// Matmul1 Dgrad2
// Matmul1 Dgrad2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
a_layout_n
,
a_layout_n
,
...
@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -401,9 +435,13 @@ std::vector<torch::Tensor> bwd_cuda(
k_lin_grads_ptr
,
k_lin_grads_ptr
,
lead_dim
,
lead_dim
,
batch_stride
,
batch_stride
,
k_lin_grads_ptr
,
lead_dim
,
batch_stride
,
attn_batches
);
attn_batches
);
// Input Linear Dgrad
// Input Linear Dgrad
THCublasCheck
(
c
u
blas
G
emm
E
x
(
handle
,
THCublasCheck
(
ro
cblas
_g
emm
_e
x
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
embed_dim
,
embed_dim
,
...
@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -411,22 +449,25 @@ std::vector<torch::Tensor> bwd_cuda(
output_lin_dim
,
output_lin_dim
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
input_weights
.
data_ptr
()),
static_cast
<
const
void
*>
(
input_weights
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
embed_dim
,
static_cast
<
const
void
*>
(
input_lin_output_grads
.
data_ptr
()),
static_cast
<
const
void
*>
(
input_lin_output_grads
.
data_ptr
()),
//static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r
,
CUDA_R_16F
,
output_lin_dim
,
output_lin_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_grads
.
data_ptr
()),
static_cast
<
void
*>
(
input_grads
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
embed_dim
,
CUDA_R_32F
,
static_cast
<
void
*>
(
input_grads
.
data_ptr
()),
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
rocblas_datatype_f16_r
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
embed_dim
,
rocblas_datatype_f32_r
,
algo
,
solution_index
,
flags
));
// Input Linear Wgrad
// Input Linear Wgrad
THCublasCheck
(
c
u
blas
G
emm
E
x
(
handle
,
THCublasCheck
(
ro
cblas
_g
emm
_e
x
(
handle
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
CUBLAS_OP_T
,
CUBLAS_OP_T
,
embed_dim
,
embed_dim
,
...
@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -434,20 +475,24 @@ std::vector<torch::Tensor> bwd_cuda(
batches
,
batches
,
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
&
alpha
),
static_cast
<
const
void
*>
(
inputs
.
data_ptr
()),
static_cast
<
const
void
*>
(
inputs
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
embed_dim
,
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
static_cast
<
const
void
*>
(
q_lin_grads_ptr
),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
output_lin_dim
,
output_lin_dim
,
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
const
void
*>
(
&
beta
),
static_cast
<
void
*>
(
input_weight_grads
.
data_ptr
()),
static_cast
<
void
*>
(
input_weight_grads
.
data_ptr
()),
CUDA_R_16F
,
rocblas_datatype_f16_r
,
embed_dim
,
static_cast
<
void
*>
(
input_weight_grads
.
data_ptr
()),
rocblas_datatype_f16_r
,
embed_dim
,
embed_dim
,
CUDA_R_32F
,
rocblas_datatype_f32_r
,
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
algo
,
solution_index
,
flags
));
auto
input_bias_grads
=
input_lin_output_grads
.
view
({
-
1
,
output_lin_dim
}).
sum
(
0
,
false
);
auto
input_bias_grads
=
input_lin_output_grads
.
view
({
-
1
,
output_lin_dim
}).
sum
(
0
,
false
);
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_DEFAULT_MATH
));
return
{
return
{
input_grads
,
input_grads
,
...
@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -458,6 +503,6 @@ std::vector<torch::Tensor> bwd_cuda(
};
};
}
}
}
// end namespace c
u
blas_gemmex
}
// end namespace
ro
cblas_gemmex
}
// end namespace self
}
// end namespace self
}
// end namespace multihead_attn
}
// end namespace multihead_attn
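
The pattern above repeats for every dense GEMM in this commit: cublasGemmEx becomes rocblas_gemm_ex, CUDA_R_16F/CUDA_R_32F become rocblas_datatype_f16_r/rocblas_datatype_f32_r, and the single cuBLAS algorithm enum is replaced by an explicit output matrix D plus rocBLAS's (algo, solution_index, flags) dispatch triple. A minimal standalone sketch of the mapping, with placeholder names (the diff itself keeps the CUBLAS_OP_* constants, which the HIP compatibility headers translate):

    // Sketch only: an fp16 GEMM with fp32 accumulation via rocblas_gemm_ex.
    // Function name and m/n/k/ld* parameters are placeholders.
    #include <rocblas.h>

    void gemm_fp16_fp32acc(rocblas_handle handle,
                           int m, int n, int k,
                           const float* alpha, const void* A, int lda,
                           const void* B, int ldb,
                           const float* beta, void* C, int ldc) {
      rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
      int32_t solution_index = 0;   // 0 = let rocBLAS pick a solution
      uint32_t flags = 0;
      rocblas_gemm_ex(handle,
                      rocblas_operation_transpose, rocblas_operation_none,
                      m, n, k,
                      alpha,
                      A, rocblas_datatype_f16_r, lda,
                      B, rocblas_datatype_f16_r, ldb,
                      beta,
                      C, rocblas_datatype_f16_r, ldc,   // C (input matrix)
                      C, rocblas_datatype_f16_r, ldc,   // D (output) aliased to C,
                                                        // matching cuBLAS in-place use
                      rocblas_datatype_f32_r,           // accumulate in fp32
                      algo, solution_index, flags);
    }
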
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp
View file @ 1436a66a
...
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
     );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self_bias::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward",  &multihead_attn::self_bias::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
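
Only the C++ namespace and the source filename change in these _cpp.cpp files; the Python-visible entry points keep their names, so callers of the extension are unaffected. A self-contained toy illustration of the binding pattern (the fwd/bwd bodies here are hypothetical stand-ins, not the real kernels):

    // toy_binding.cpp: minimal illustration of the pybind11 pattern above.
    #include <torch/extension.h>

    namespace multihead_attn { namespace self_bias { namespace rocblas_gemmex {
    // Hypothetical stand-ins; the real fwd/bwd live in the CUDA/HIP sources.
    torch::Tensor fwd(torch::Tensor x) { return x; }
    torch::Tensor bwd(torch::Tensor g) { return g; }
    }}} // namespaces

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      // Exported names stay "forward"/"backward"; only the C++ namespace moved.
      m.def("forward",  &multihead_attn::self_bias::rocblas_gemmex::fwd, "Forward.");
      m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Backward.");
    }
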
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
View file @ 1436a66a
 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -21,7 +24,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -80,11 +83,10 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,
...
@@ -92,17 +94,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta_one),
                 q_lin_results_ptr,
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr,
+                rocblas_datatype_f16_r,
+                output_lin_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
...
@@ -122,7 +129,11 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(softmax_results_ptr),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(softmax_results_ptr),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Padded Softmax
   bool softmax_success = false;
   if (pad_mask == nullptr) {
...
@@ -180,12 +191,15 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(matmul2_results.data_ptr()),
                         head_dim*attn_batches,
                         head_dim,
+                        static_cast<half*>(matmul2_results.data_ptr()),
+                        head_dim*attn_batches,
+                        head_dim,
                         attn_batches);
   outputs.copy_(output_biases);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -193,20 +207,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta_one),
                 static_cast<void*>(outputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(outputs.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_lin_results,
...
@@ -275,10 +291,8 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -286,19 +300,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_lin_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -306,17 +326,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
   // MatMul2 Dgrad1
...
@@ -337,6 +362,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(matmul2_grads.data_ptr()),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Matmul2 Dgrad2
...
@@ -357,6 +385,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         v_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        v_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
...
@@ -385,7 +416,10 @@ std::vector<torch::Tensor> bwd_cuda(
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
                         beta,
+                        q_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         q_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
...
@@ -408,10 +442,13 @@ std::vector<torch::Tensor> bwd_cuda(
                         beta,
                         k_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        k_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -419,22 +456,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(input_lin_output_grads.data_ptr()),
                 //static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -442,20 +482,24 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_grads,
...
@@ -466,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
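
The #undef lines added at the top of each .cu file are what make half-precision arithmetic compile under HIP: PyTorch's ROCm build defines __HIP_NO_HALF_OPERATORS__ and __HIP_NO_HALF_CONVERSIONS__, which strip the __half operator overloads out of hip_fp16.h. A minimal sketch of the kind of device code that relies on them (hypothetical kernel, assuming a ROCm build):

    // Sketch: the two #undefs re-enable __half operators before hip_fp16.h is seen.
    #undef __HIP_NO_HALF_OPERATORS__
    #undef __HIP_NO_HALF_CONVERSIONS__
    #include <hip/hip_runtime.h>
    #include <hip/hip_fp16.h>

    // Hypothetical kernel: scales a half-precision buffer in place.
    __global__ void scale_half(__half* x, __half s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) x[i] = x[i] * s;  // needs __half operator*, hence the #undefs
    }
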
apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp
View file @ 1436a66a
...
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -121,12 +121,12 @@ std::vector<torch::Tensor> bwd(
     );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self::cublas_gemmex::fwd, "Self Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::self::cublas_gemmex::bwd, "Self Multihead Attention Backward.");
+  m.def("forward",  &multihead_attn::self::rocblas_gemmex::fwd, "Self Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
 }
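
Note that each rocblas_gemm_ex call stays wrapped in THCublasCheck; on ROCm builds the hipified THC headers are assumed to map that check onto rocBLAS status codes. For rocBLAS code outside of Torch, a hypothetical equivalent check could look like:

    // Hypothetical helper in the spirit of THCublasCheck, for plain rocBLAS code.
    #include <rocblas.h>
    #include <cstdio>
    #include <cstdlib>

    #define ROCBLAS_CHECK(call)                                   \
      do {                                                        \
        rocblas_status s_ = (call);                               \
        if (s_ != rocblas_status_success) {                       \
          std::fprintf(stderr, "rocBLAS error %d at %s:%d\n",     \
                       (int)s_, __FILE__, __LINE__);              \
          std::abort();                                           \
        }                                                         \
      } while (0)
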
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @ 1436a66a
 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -21,7 +24,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -78,9 +81,8 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,
...
@@ -88,17 +90,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 q_lin_results_ptr,
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr,
+                rocblas_datatype_f16_r,
+                output_lin_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
...
@@ -118,6 +125,9 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(softmax_results_ptr),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(softmax_results_ptr),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Padded Softmax
...
@@ -179,10 +189,13 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(matmul2_results.data_ptr()),
                         head_dim*attn_batches,
                         head_dim,
+                        static_cast<half*>(matmul2_results.data_ptr()),
+                        head_dim*attn_batches,
+                        head_dim,
                         attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -190,19 +203,22 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(outputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(outputs.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_lin_results,
...
@@ -270,11 +286,9 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -282,20 +296,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_lin_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -303,17 +322,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(output_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,
...
@@ -333,6 +357,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(matmul2_grads.data_ptr()),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Matmul2 Dgrad2
...
@@ -353,6 +380,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         v_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        v_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
...
@@ -392,6 +422,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         q_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        q_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Matmul1 Dgrad2
...
@@ -411,11 +444,14 @@ std::vector<torch::Tensor> bwd_cuda(
                         beta,
                         k_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        k_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -423,20 +459,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -444,18 +485,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(inputs.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                rocblas_datatype_f16_r,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_grads.data_ptr()),
+                rocblas_datatype_f16_r,
+                embed_dim,
+                rocblas_datatype_f32_r,
+                algo,
+                solution_index,
+                flags));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_grads,
...
@@ -464,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
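
The gemm_switch_fp32accum edits follow the same logic as the plain GEMMs: each strided-batched call gains a second copy of the output pointer, leading dimension, and batch stride, because rocBLAS's batched GEMM also takes a distinct D matrix. Aliasing D to C reproduces the in-place cuBLAS behavior. A sketch with placeholder names:

    // Sketch: a strided-batched fp16 GEMM with fp32 accumulation in rocBLAS.
    #include <rocblas.h>

    void batched_fp16_gemm(rocblas_handle h, int m, int n, int k,
                           const float* alpha,
                           const void* A, int lda, rocblas_stride sA,
                           const void* B, int ldb, rocblas_stride sB,
                           const float* beta,
                           void* C, int ldc, rocblas_stride sC,
                           int batch_count) {
      rocblas_gemm_strided_batched_ex(h,
          rocblas_operation_none, rocblas_operation_none,
          m, n, k, alpha,
          A, rocblas_datatype_f16_r, lda, sA,
          B, rocblas_datatype_f16_r, ldb, sB,
          beta,
          C, rocblas_datatype_f16_r, ldc, sC,   // C (input)
          C, rocblas_datatype_f16_r, ldc, sC,   // D (output) aliased to C
          batch_count,
          rocblas_datatype_f32_r,               // fp32 accumulation
          rocblas_gemm_algo_standard, 0, 0);
    }
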
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp
View file @ 1436a66a
...
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -167,7 +167,7 @@ std::vector<torch::Tensor> bwd(
 } // end namespace multihead_attn
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self_norm_add::cublas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
-  m.def("backward", &multihead_attn::self_norm_add::cublas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
+  m.def("forward",  &multihead_attn::self_norm_add::rocblas_gemmex::fwd, "Self Multihead Attention Plus Layer Norm and Residual Add Forward.");
+  m.def("backward", &multihead_attn::self_norm_add::rocblas_gemmex::bwd, "Self Multihead Attention Plus Layer Norm and Residual Add Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
View file @ 1436a66a
 #include <vector>
 #include <iostream>
+//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
+#undef __HIP_NO_HALF_OPERATORS__
+#undef __HIP_NO_HALF_CONVERSIONS__
+//#endif
 #include <ATen/ATen.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
...
@@ -21,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_norm_add {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
...
@@ -88,7 +92,7 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm
   HostApplyLayerNorm<at::Half, float>(
       static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...
@@ -102,7 +106,7 @@ std::vector<torch::Tensor> fwd_cuda(
       static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Fwd
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 output_lin_dim,
...
@@ -110,18 +114,23 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 //static_cast<const void*>(inputs.data_ptr()),
                 static_cast<const void*>(lyr_nrm_results.data_ptr()),
-                CUDA_R_16F,
+                b_type,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 q_lin_results_ptr,
-                CUDA_R_16F,
+                c_type,
                 output_lin_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                q_lin_results_ptr,
+                d_type,
+                output_lin_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
...
@@ -141,6 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(softmax_results_ptr),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(softmax_results_ptr),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Padded Softmax
...
@@ -202,11 +214,14 @@ std::vector<torch::Tensor> fwd_cuda(
                         beta,
                         static_cast<half*>(matmul2_results.data_ptr()),
                         head_dim*attn_batches,
                         head_dim,
+                        static_cast<half*>(matmul2_results.data_ptr()),
+                        head_dim*attn_batches,
+                        head_dim,
                         attn_batches);
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_T,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -214,18 +229,24 @@ std::vector<torch::Tensor> fwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                b_type,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_lin_results.data_ptr()),
-                CUDA_R_16F,
+                c_type,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_results.data_ptr()),
+                d_type,
+                embed_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // End-of-block Dropout-Add
   if (is_training) {
     apex_dropout_add_cuda<at::Half, float, uint32_t>(
...
@@ -243,8 +264,6 @@ std::vector<torch::Tensor> fwd_cuda(
       total_tokens);
   }
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     lyr_nrm_results,
     lyr_nrm_mean,
...
@@ -327,8 +346,6 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Dropout Add Backward
   apex_masked_scale_cuda<at::Half, float, uint32_t>(
       static_cast<at::Half const*>(output_grads.data_ptr()),
...
@@ -338,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -346,20 +363,25 @@ std::vector<torch::Tensor> bwd_cuda(
                 embed_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(output_weights.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 static_cast<const void*>(dropout_add_grads.data_ptr()),
-                CUDA_R_16F,
+                b_type,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_lin_grads.data_ptr()),
-                CUDA_R_16F,
+                c_type,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_lin_grads.data_ptr()),
+                d_type,
+                embed_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -367,18 +389,23 @@ std::vector<torch::Tensor> bwd_cuda(
                 batches,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(matmul2_results.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 static_cast<const void*>(dropout_add_grads.data_ptr()),
-                CUDA_R_16F,
+                b_type,
                 embed_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(output_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                c_type,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(output_weight_grads.data_ptr()),
+                d_type,
+                embed_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,
                         a_layout_t,
...
@@ -397,6 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
+                        static_cast<half*>(matmul2_grads.data_ptr()),
+                        k_seq_len,
+                        k_seq_len*q_seq_len,
                         attn_batches);
   // Matmul2 Dgrad2
...
@@ -417,6 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         v_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        v_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
...
@@ -455,6 +488,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         beta,
                         q_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        q_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
...
@@ -475,11 +511,14 @@ std::vector<torch::Tensor> bwd_cuda(
                         beta,
                         k_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
+                        k_lin_grads_ptr,
+                        lead_dim,
+                        batch_stride,
                         attn_batches);
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_N,
                 embed_dim,
...
@@ -487,22 +526,26 @@ std::vector<torch::Tensor> bwd_cuda(
                 output_lin_dim,
                 static_cast<const void*>(&alpha),
                 static_cast<const void*>(input_weights.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                b_type,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 //static_cast<void*>(input_grads.data_ptr()),
                 static_cast<void*>(input_lin_grads.data_ptr()),
-                CUDA_R_16F,
+                c_type,
                 embed_dim,
-                CUDA_R_32F,
-                //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_lin_grads.data_ptr()),
+                d_type,
+                embed_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                 CUBLAS_OP_N,
                 CUBLAS_OP_T,
                 embed_dim,
...
@@ -511,17 +554,22 @@ std::vector<torch::Tensor> bwd_cuda(
                 static_cast<const void*>(&alpha),
                 //static_cast<const void*>(inputs.data_ptr()),
                 static_cast<const void*>(lyr_nrm_results.data_ptr()),
-                CUDA_R_16F,
+                a_type,
                 embed_dim,
                 static_cast<const void*>(q_lin_grads_ptr),
-                CUDA_R_16F,
+                b_type,
                 output_lin_dim,
                 static_cast<const void*>(&beta),
                 static_cast<void*>(input_weight_grads.data_ptr()),
-                CUDA_R_16F,
+                c_type,
                 embed_dim,
-                CUDA_R_32F,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                static_cast<void*>(input_weight_grads.data_ptr()),
+                d_type,
+                embed_dim,
+                compute_type,
+                algo,
+                solution_index,
+                flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
...
@@ -540,7 +588,6 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
   );
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_grads,
...
@@ -551,6 +598,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self_norm_add
 } // end namespace multihead_attn
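
Unlike the earlier files, this one routes the datatype and dispatch arguments through local names (a_type, b_type, c_type, d_type, compute_type, algo, solution_index, flags) whose declarations sit above the visible hunks. Plausible definitions, stated here as an assumption rather than a copy of the source:

    // Assumed declarations (not visible in the diff): fp16 in/out, fp32 accumulate.
    #include <rocblas.h>
    #include <cstdint>

    static rocblas_datatype a_type       = rocblas_datatype_f16_r;
    static rocblas_datatype b_type       = rocblas_datatype_f16_r;
    static rocblas_datatype c_type       = rocblas_datatype_f16_r;
    static rocblas_datatype d_type       = rocblas_datatype_f16_r;
    static rocblas_datatype compute_type = rocblas_datatype_f32_r;
    static rocblas_gemm_algo algo        = rocblas_gemm_algo_standard;
    static int32_t  solution_index      = 0;  // 0 lets rocBLAS choose internally
    static uint32_t flags               = 0;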