OpenDAS / apex, commit 4e7a2a8e (parent 2a4864d5)
Authored Nov 28, 2023 by flyingdown

    fix up for torch2.1

Showing 15 changed files with 382 additions and 574 deletions (+382 -574):
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   +65  -101
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          +65  -101
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +44  -68
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +45  -69
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     +44  -68
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            +44  -68
  apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                        +27  -28
  apex/contrib/test/run_rocm_extensions.py                                         +1   -0
  csrc/fused_dense_cuda.cu                                                         +21  -33
  csrc/mlp_cuda.cu                                                                 +21  -33
  tests/distributed/DDP/ddp_race_condition_test.py                                 +1   -1
  tests/distributed/amp_master_params/amp_master_params.py                         +1   -1
  tests/distributed/synced_batchnorm/test_groups.py                                +1   -1
  tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py          +1   -1
  tests/distributed/synced_batchnorm/two_gpu_unit_test.py                          +1   -1
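The bulk of the change is mechanical: the rocblas_gemm_ex calls in these kernels are rewritten as the corresponding hipBLAS call (hipblasGemmEx), rocblas_datatype_f16_r / rocblas_datatype_f32_r become HIPBLAS_R_16F / HIPBLAS_R_32F, and the separate D-matrix, solution_index and flags arguments are dropped in favor of a plain HIPBLAS_GEMM_DEFAULT algorithm selector. The sketch below is not part of the commit; it only illustrates that argument mapping, assuming the pre-HIPBLAS_V2 hipblasGemmEx signature (hipblasDatatype_t for the compute type) and placeholder names for the handle, sizes and buffers.

// Illustrative sketch (not from the commit): an fp16 GEMM with fp32 accumulation,
// shown as the old rocBLAS call (commented out) next to the hipBLAS call this
// commit switches to. Assumes the pre-HIPBLAS_V2 signature of hipblasGemmEx.
#include <hipblas.h>   // newer ROCm releases install this as <hipblas/hipblas.h>

hipblasStatus_t gemm_fp16_fp32accum(hipblasHandle_t handle,
                                    int m, int n, int k,
                                    const float* alpha, const float* beta,
                                    const void* A, int lda,   // fp16 matrix A
                                    const void* B, int ldb,   // fp16 matrix B
                                    void* C, int ldc) {       // fp16 matrix C, updated in place
  // Old form removed by the commit: rocblas_gemm_ex takes a separate output
  // matrix D plus algo / solution_index / flags arguments.
  //
  //   rocblas_gemm_ex(handle,
  //                   rocblas_operation_transpose, rocblas_operation_none,
  //                   m, n, k, alpha,
  //                   A, rocblas_datatype_f16_r, lda,
  //                   B, rocblas_datatype_f16_r, ldb, beta,
  //                   C, rocblas_datatype_f16_r, ldc,      // C
  //                   C, rocblas_datatype_f16_r, ldc,      // D (in place)
  //                   rocblas_datatype_f32_r,              // accumulate in fp32
  //                   rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);

  // New form added by the commit: C is the single in-place output, and the
  // solution_index / flags arguments disappear.
  return hipblasGemmEx(handle,
                       HIPBLAS_OP_T, HIPBLAS_OP_N,
                       m, n, k,
                       alpha,
                       A, HIPBLAS_R_16F, lda,
                       B, HIPBLAS_R_16F, ldb,
                       beta,
                       C, HIPBLAS_R_16F, ldc,
                       HIPBLAS_R_32F,          // compute/accumulation type
                       HIPBLAS_GEMM_DEFAULT);  // algo
}

(In the apex sources the operation arguments appear as CUBLAS_OP_T / CUBLAS_OP_N because the files are built through PyTorch's hipification, which maps those macros to the HIP equivalents.)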
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

@@ -85,12 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
                                        CUBLAS_OP_T,
                                        CUBLAS_OP_N,
                                        output_lin_q_dim,

@@ -98,25 +98,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                                        embed_dim,
                                        static_cast<const void*>(&alpha),
                                        static_cast<const void*>(input_weights_q.data_ptr()),
-                                       rocblas_datatype_f16_r,
+                                       HIPBLAS_R_16F,
                                        embed_dim,
                                        static_cast<const void*>(inputs_q.data_ptr()),
-                                       rocblas_datatype_f16_r,
+                                       HIPBLAS_R_16F,
                                        embed_dim,
                                        static_cast<const void*>(&beta),
                                        q_lin_results_ptr,
-                                       rocblas_datatype_f16_r,
+                                       HIPBLAS_R_16F,
                                        output_lin_q_dim,
-                                       q_lin_results_ptr,
-                                       rocblas_datatype_f16_r,
-                                       output_lin_q_dim,
-                                       rocblas_datatype_f32_r,
-                                       rocblas_gemm_algo_standard /*algo*/,
-                                       0 /*solution_index*/,
-                                       flags));
+                                       HIPBLAS_R_32F,
+                                       HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
                                        CUBLAS_OP_T,
                                        CUBLAS_OP_N,
                                        output_lin_kv_dim,

The remaining hunks in this file apply the same rewrite (rocblas_gemm_ex -> hipblasGemmEx, rocblas_datatype_f16_r / f32_r -> HIPBLAS_R_16F / HIPBLAS_R_32F, duplicate D-matrix, solution_index and flags arguments dropped in favor of HIPBLAS_GEMM_DEFAULT) to the other GEMM calls; the gemm_switch_fp32accum calls stay unchanged:

  @@ -124,22 +120,18 @@   Input Linear KV Fwd (tail of the call) in fwd_cuda
  @@ -219,7 +211,7 @@ and @@ -227,22 +219,18 @@   Output Linear in fwd_cuda
  @@ -318,7 +306,7 @@   rocblas_int flags -> int flags in bwd_cuda
  @@ -332,7 +320,7 @@ and @@ -340,25 +328,21 @@   Output Linear Dgrad in bwd_cuda
  @@ -366,22 +350,18 @@   Output Linear Wgrad in bwd_cuda
  @@ -493,7 +473,7 @@ and @@ -501,25 +481,21 @@   Input Linear Q Dgrad in bwd_cuda
  @@ -527,25 +503,21 @@   Input Linear Q Wgrad in bwd_cuda
  @@ -553,25 +525,21 @@   Input Linear KV Dgrad in bwd_cuda
  @@ -579,22 +547,18 @@   Input Linear KV Wgrad in bwd_cuda
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

Same rewrite in the norm-add variant; the rocblas_datatype_* arguments carrying /*a_type*/, /*b_type*/, /*c_type*/, /*d_type*/ and /*compute_type*/ comments become HIPBLAS_R_16F / HIPBLAS_R_32F with the comments kept. Affected hunks:

  @@ -101,7 +101,7 @@   rocblas_int flags -> int flags in fwd_cuda
  @@ -116,7 +116,7 @@ and @@ -124,26 +124,22 @@   Input Linear Q Fwd (reads lyr_nrm_results) in fwd_cuda
  @@ -151,22 +147,18 @@   Input Linear KV Fwd in fwd_cuda
  @@ -246,7 +238,7 @@ and @@ -254,22 +246,18 @@   Output Linear (writes output_lin_results) in fwd_cuda
  @@ -374,7 +362,7 @@   rocblas_int flags -> int flags in bwd_cuda
  @@ -396,7 +384,7 @@ and @@ -404,25 +392,21 @@   Output Linear Dgrad (reads dropout_add_grads) in bwd_cuda
  @@ -430,22 +414,18 @@   Output Linear Wgrad in bwd_cuda
  @@ -557,7 +537,7 @@ and @@ -565,26 +545,22 @@   Input Linear Q Dgrad (writes input_lin_q_grads) in bwd_cuda
  @@ -592,25 +568,21 @@   Input Linear Q Wgrad in bwd_cuda
  @@ -618,25 +590,21 @@   Input Linear KV Dgrad in bwd_cuda
  @@ -644,22 +612,18 @@   Input Linear KV Wgrad in bwd_cuda

The layer-norm calls, the gemm_switch_fp32accum calls, the end-of-block dropout-add, and the HostLayerNormGradient<half, float> call are untouched.
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

Same rewrite in the biased, additive-mask self-attention kernels; the bias handling (input_lin_results.copy_(input_biases), outputs.copy_(output_biases), beta_one, and the .view({-1, ...}).sum(0, false) bias-gradient reductions) is untouched. Affected hunks:

  @@ -80,13 +80,13 @@   rocblas_int flags -> int flags and Input Linear Fwd in fwd_cuda
  @@ -94,22 +94,18 @@   Input Linear Fwd (tail, beta_one) in fwd_cuda
  @@ -183,7 +179,7 @@ and @@ -191,22 +187,18 @@   Output Linear in fwd_cuda
  @@ -267,7 +259,7 @@   rocblas_int flags -> int flags in bwd_cuda
  @@ -281,7 +273,7 @@ and @@ -289,25 +281,21 @@   Output Linear Dgrad in bwd_cuda
  @@ -315,22 +303,18 @@   Output Linear Wgrad in bwd_cuda
  @@ -441,7 +425,7 @@ and @@ -449,25 +433,21 @@   Input Linear Dgrad (reads input_lin_output_grads) in bwd_cuda
  @@ -475,22 +455,18 @@   Input Linear Wgrad in bwd_cuda
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

Same rewrite in the biased self-attention kernels. Affected hunks:

  @@ -78,13 +78,13 @@   rocblas_int flags -> int flags and Input Linear Fwd in fwd_cuda
  @@ -92,22 +92,18 @@   Input Linear Fwd (tail, beta_one) in fwd_cuda
  @@ -189,7 +185,7 @@ and @@ -197,22 +193,18 @@   Output Linear in fwd_cuda
  @@ -273,7 +265,7 @@   rocblas_int flags -> int flags in bwd_cuda
  @@ -287,7 +279,7 @@ and @@ -295,25 +287,21 @@   Output Linear Dgrad in bwd_cuda
  @@ -321,22 +309,18 @@   Output Linear Wgrad in bwd_cuda
  @@ -441,7 +425,7 @@ and @@ -449,25 +433,21 @@   Input Linear Dgrad in bwd_cuda
  @@ -475,22 +455,18 @@   Input Linear Wgrad in bwd_cuda
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @ 4e7a2a8e
...
@@ -77,12 +77,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_T,
       CUBLAS_OP_N,
       output_lin_dim,
...
@@ -90,22 +90,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(input_weights.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(inputs.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(&beta),
       q_lin_results_ptr,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       output_lin_dim,
-      q_lin_results_ptr,
-      rocblas_datatype_f16_r,
-      output_lin_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
...
@@ -185,7 +181,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_T,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -193,22 +189,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(output_weights.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(matmul2_results.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(outputs.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
-      static_cast<void*>(outputs.data_ptr()),
-      rocblas_datatype_f16_r,
-      embed_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {input_lin_results, softmax_results, dropout_results,
...
@@ -269,7 +261,7 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 #ifdef __HIP_PLATFORM_HCC__
...
@@ -283,7 +275,7 @@ std::vector<torch::Tensor> bwd_cuda(
 #endif
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -291,25 +283,21 @@ std::vector<torch::Tensor> bwd_cuda(
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(output_weights.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(output_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(output_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
-      static_cast<void*>(output_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r,
-      embed_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_T,
       embed_dim,
...
@@ -317,22 +305,18 @@ std::vector<torch::Tensor> bwd_cuda(
       batches,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(matmul2_results.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(output_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(output_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
-      static_cast<void*>(output_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r,
-      embed_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
...
@@ -444,7 +428,7 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -452,25 +436,21 @@ std::vector<torch::Tensor> bwd_cuda(
       output_lin_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(input_weights.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(q_lin_grads_ptr),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       output_lin_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(input_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
-      static_cast<void*>(input_grads.data_ptr()),
-      rocblas_datatype_f16_r,
-      embed_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_T,
       embed_dim,
...
@@ -478,22 +458,18 @@ std::vector<torch::Tensor> bwd_cuda(
       batches,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(inputs.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
       static_cast<const void*>(q_lin_grads_ptr),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       output_lin_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(input_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       embed_dim,
-      static_cast<void*>(input_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r,
-      embed_dim,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
...
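The pattern above repeats at every call site in this commit: rocblas_gemm_ex's separate D output (pointer, type, ldd) and its trailing solution_index/flags arguments have no counterpart in hipblasGemmEx, which writes the result in place into C. The following standalone sketch (not part of the repository; sizes, buffer names and the include path are illustrative, and it assumes the ROCm 5.x-era hipBLAS enums used in this diff) shows the resulting call shape for an fp16 GEMM with fp32 accumulation:

// gemm_ex_sketch.cpp -- minimal hipblasGemmEx call-shape sketch (hypothetical example).
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <cstdio>

int main() {
  const int m = 128, n = 128, k = 128;
  void *d_a, *d_b, *d_c;                    // device buffers holding fp16 data
  hipMalloc(&d_a, 2u * m * k);              // 2 bytes per fp16 element
  hipMalloc(&d_b, 2u * k * n);
  hipMalloc(&d_c, 2u * m * n);
  hipMemset(d_a, 0, 2u * m * k);            // zero fill is enough for a call-shape demo
  hipMemset(d_b, 0, 2u * k * n);
  hipMemset(d_c, 0, 2u * m * n);

  hipblasHandle_t handle;
  hipblasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;    // scalars match the fp32 compute type

  hipblasStatus_t status = hipblasGemmEx(
      handle,
      HIPBLAS_OP_N, HIPBLAS_OP_N,
      m, n, k,
      &alpha,
      d_a, HIPBLAS_R_16F, m,                // A, a_type, lda
      d_b, HIPBLAS_R_16F, k,                // B, b_type, ldb
      &beta,
      d_c, HIPBLAS_R_16F, m,                // C, c_type, ldc -- C is also the output
      HIPBLAS_R_32F,                        // compute type: accumulate in fp32
      HIPBLAS_GEMM_DEFAULT);                // algo; no solution_index/flags arguments
  printf("hipblasGemmEx status: %d\n", (int)status);

  hipblasDestroy(handle);
  hipFree(d_a); hipFree(d_b); hipFree(d_c);
  return 0;
}

Built with something like `hipcc gemm_ex_sketch.cpp -lhipblas`, this exercises the same argument ordering the converted kernels rely on.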
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
View file @ 4e7a2a8e
...
@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm
...
@@ -103,7 +103,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
   // Input Linear Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_T,
       CUBLAS_OP_N,
       output_lin_dim,
...
@@ -111,23 +111,19 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(input_weights.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       //static_cast<const void*>(inputs.data_ptr()),
       static_cast<const void*>(lyr_nrm_results.data_ptr()),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       embed_dim,
       static_cast<const void*>(&beta),
       q_lin_results_ptr,
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       output_lin_dim,
-      q_lin_results_ptr,
-      rocblas_datatype_f16_r /*d_type*/,
-      output_lin_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(a_layout_t,
...
@@ -208,7 +204,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       flags);
   // Output Linear
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_T,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -216,22 +212,18 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(output_weights.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       static_cast<const void*>(matmul2_results.data_ptr()),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(output_lin_results.data_ptr()),
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       embed_dim,
-      static_cast<void*>(output_lin_results.data_ptr()),
-      rocblas_datatype_f16_r /*d_type*/,
-      embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // End-of-block Dropout-Add
...
@@ -320,7 +312,7 @@ std::vector<torch::Tensor> bwd_cuda(
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
 #ifdef __HIP_PLATFORM_HCC__
...
@@ -341,7 +333,7 @@ std::vector<torch::Tensor> bwd_cuda(
       (1.0 / (1.0 - dropout_prob)));
   // Output Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -349,25 +341,21 @@ std::vector<torch::Tensor> bwd_cuda(
       embed_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(output_weights.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       static_cast<const void*>(dropout_add_grads.data_ptr()),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(output_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       embed_dim,
-      static_cast<void*>(output_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r /*d_type*/,
-      embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Output Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_T,
       embed_dim,
...
@@ -375,22 +363,18 @@ std::vector<torch::Tensor> bwd_cuda(
       batches,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(matmul2_results.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       static_cast<const void*>(dropout_add_grads.data_ptr()),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       embed_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(output_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       embed_dim,
-      static_cast<void*>(output_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r /*d_type*/,
-      embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(a_layout_t,
...
@@ -502,7 +486,7 @@ std::vector<torch::Tensor> bwd_cuda(
       flags);
   // Input Linear Dgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_N,
       embed_dim,
...
@@ -510,26 +494,22 @@ std::vector<torch::Tensor> bwd_cuda(
       output_lin_dim,
       static_cast<const void*>(&alpha),
       static_cast<const void*>(input_weights.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       static_cast<const void*>(q_lin_grads_ptr),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       output_lin_dim,
       static_cast<const void*>(&beta),
       //static_cast<void*>(input_grads.data_ptr()),
       static_cast<void*>(input_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       embed_dim,
-      static_cast<void*>(input_lin_grads.data_ptr()),
-      rocblas_datatype_f16_r /*d_type*/,
-      embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Input Linear Wgrad
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
       CUBLAS_OP_N,
       CUBLAS_OP_T,
       embed_dim,
...
@@ -538,22 +518,18 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<const void*>(&alpha),
       //static_cast<const void*>(inputs.data_ptr()),
       static_cast<const void*>(lyr_nrm_results.data_ptr()),
-      rocblas_datatype_f16_r /*a_type*/,
+      HIPBLAS_R_16F /*a_type*/,
       embed_dim,
       static_cast<const void*>(q_lin_grads_ptr),
-      rocblas_datatype_f16_r /*b_type*/,
+      HIPBLAS_R_16F /*b_type*/,
       output_lin_dim,
       static_cast<const void*>(&beta),
       static_cast<void*>(input_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r /*c_type*/,
+      HIPBLAS_R_16F /*c_type*/,
       embed_dim,
-      static_cast<void*>(input_weight_grads.data_ptr()),
-      rocblas_datatype_f16_r /*d_type*/,
-      embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
-      rocblas_gemm_algo_standard /*algo*/,
-      0 /*solution_index*/,
-      flags));
+      HIPBLAS_R_32F /*compute_type*/,
+      HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
...
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh
View file @ 4e7a2a8e
...
@@ -17,15 +17,15 @@
 // symbol to be automatically resolved by PyTorch libs
 /*
-rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
+rocblas_datatype a_type = HIPBLAS_R_16F; // OK
-rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
+rocblas_datatype b_type = HIPBLAS_R_16F; // OK
-rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
+rocblas_datatype c_type = HIPBLAS_R_16F; // OK
-rocblas_datatype d_type = rocblas_datatype_f16_r;
+rocblas_datatype d_type = HIPBLAS_R_16F;
-rocblas_datatype compute_type = rocblas_datatype_f32_r;
+rocblas_datatype compute_type = HIPBLAS_R_32F;
-rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
+rocblas_gemm_algo algo = HIPBLAS_GEMM_DEFAULT;
 int32_t solution_index = 0;
-rocblas_int flags = 0;
+int flags = 0;
 */
 namespace {
...
@@ -44,38 +44,37 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, hipblasGemmAlgo_t algo, int flags) {
-  cublasOperation_t opa = convertTransToCublasOperation(transa);
+  hipblasOperation_t opa = convertTransToCublasOperation(transa);
-  cublasOperation_t opb = convertTransToCublasOperation(transb);
+  hipblasOperation_t opb = convertTransToCublasOperation(transb);
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+  hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
-  cublasSetStream(handle, stream);
+  hipblasSetStream(handle, stream);
   float fAlpha = alpha;
   float fBeta = beta;
-  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  //THCublasCheck(hipblasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx(handle,
       opa, opb, (int)m, (int)n, (int)k,
-      (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+      (void*)&fAlpha, (const void*)a, HIPBLAS_R_16F /*a_type*/, (int)lda, strideA,
-      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
+      (const void*)b, HIPBLAS_R_16F /*b_type*/, (int)ldb, strideB,
-      (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+      (void*)&fBeta, (void*)c, HIPBLAS_R_16F /*c_type*/, (int)ldc, strideC,
-      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/,
-      algo, 0 /*solution_index*/, flags));
+      (int)batchCount, HIPBLAS_R_32F /*compute_type*/,
+      algo));
 }
 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, int flags) {
-  auto stream = c10::cuda::getCurrentCUDAStream();
+  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA();
   if ((transa == 't') && (transb == 'n')) {
-    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
-    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
   } else if ((transa == 'n') && (transb == 'n')) {
-    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
-    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
   } else if ((transa == 'n') && (transb == 't')) {
-    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
-    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+    else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, HIPBLAS_GEMM_DEFAULT, flags); }
   } else {
     AT_ASSERTM(false, "TransA and TransB are invalid");
   }
...
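The batched wrapper above keeps the old RocblasStridedBatchedGemm signature (including the now-unused d/ldd/strideD and flags parameters) while forwarding to hipBLAS. A standalone sketch of the resulting call, not taken from the repository and with hypothetical shapes, strides and buffer names, assuming the same ROCm 5.x-era hipBLAS enums:

// strided_batched_sketch.cpp -- minimal hipblasGemmStridedBatchedEx call-shape sketch.
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <cstdio>

int main() {
  const int m = 64, n = 64, k = 64, batch = 8;
  const long long strideA = (long long)m * k;   // one contiguous slab per batch entry
  const long long strideB = (long long)k * n;
  const long long strideC = (long long)m * n;

  void *d_a, *d_b, *d_c;                        // fp16 device buffers
  hipMalloc(&d_a, 2ull * strideA * batch);
  hipMalloc(&d_b, 2ull * strideB * batch);
  hipMalloc(&d_c, 2ull * strideC * batch);
  hipMemset(d_a, 0, 2ull * strideA * batch);
  hipMemset(d_b, 0, 2ull * strideB * batch);
  hipMemset(d_c, 0, 2ull * strideC * batch);

  hipblasHandle_t handle;
  hipblasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;

  hipblasStatus_t status = hipblasGemmStridedBatchedEx(
      handle, HIPBLAS_OP_T, HIPBLAS_OP_N, m, n, k,
      &alpha,
      d_a, HIPBLAS_R_16F, k, strideA,           // A, a_type, lda, strideA (A^T path, lda = k)
      d_b, HIPBLAS_R_16F, k, strideB,           // B, b_type, ldb, strideB
      &beta,
      d_c, HIPBLAS_R_16F, m, strideC,           // C, c_type, ldc, strideC -- the output
      batch,
      HIPBLAS_R_32F,                            // accumulate in fp32
      HIPBLAS_GEMM_DEFAULT);                    // algo only; no D tensor, solution_index or flags
  printf("hipblasGemmStridedBatchedEx status: %d\n", (int)status);

  hipblasDestroy(handle);
  hipFree(d_a); hipFree(d_b); hipFree(d_c);
  return 0;
}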
apex/contrib/test/run_rocm_extensions.py
View file @ 4e7a2a8e
...
@@ -4,6 +4,7 @@ import sys
 test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = [
+    "groupbn",
     "layer_norm"
 ]
...
csrc/fused_dense_cuda.cu
View file @ 4e7a2a8e
...
@@ -31,7 +31,7 @@ cublasStatus_t gemm_bias(
     double* C,
     int ldc) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -40,22 +40,18 @@ cublasStatus_t gemm_bias(
       k,
       alpha,
       A,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       lda,
       B,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldc,
-      C,
-      rocblas_datatype_f64_r,
-      ldc,
-      rocblas_datatype_f64_r,
-      rocblas_gemm_algo_standard,
-      0,
-      0);
+      HIPBLAS_R_64F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,
...
@@ -97,7 +93,7 @@ cublasStatus_t gemm_bias(
     float* C,
     int ldc) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -106,22 +102,18 @@ cublasStatus_t gemm_bias(
       k,
       alpha,
       A,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       lda,
       B,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       ldc,
-      C,
-      rocblas_datatype_f32_r,
-      ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0,
-      0);
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
...
@@ -164,7 +156,7 @@ cublasStatus_t gemm_bias(
     at::Half* C,
     int ldc) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -173,22 +165,18 @@ cublasStatus_t gemm_bias(
       k,
       alpha,
       A,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       lda,
       B,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       ldc,
-      C,
-      rocblas_datatype_f16_r,
-      ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0,
-      0);
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,
...
csrc/mlp_cuda.cu
View file @ 4e7a2a8e
...
@@ -78,7 +78,7 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -87,22 +87,18 @@ cublasStatus_t mlp_gemm(
       k,
       alpha,
       A,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       lda,
       B,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldc,
-      C,
-      rocblas_datatype_f64_r,
-      ldc,
-      rocblas_datatype_f64_r,
-      rocblas_gemm_algo_standard,
-      0,
-      flag);
+      HIPBLAS_R_64F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,
...
@@ -145,7 +141,7 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -154,22 +150,18 @@ cublasStatus_t mlp_gemm(
       k,
       alpha,
       A,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       lda,
       B,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f32_r,
+      HIPBLAS_R_32F,
       ldc,
-      C,
-      rocblas_datatype_f32_r,
-      ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0,
-      flag);
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
...
@@ -213,7 +205,7 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,
...
@@ -222,22 +214,18 @@ cublasStatus_t mlp_gemm(
       k,
       alpha,
       A,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       lda,
       B,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f16_r,
+      HIPBLAS_R_16F,
       ldc,
-      C,
-      rocblas_datatype_f16_r,
-      ldc,
-      rocblas_datatype_f32_r,
-      rocblas_gemm_algo_standard,
-      0,
-      flag);
+      HIPBLAS_R_32F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,
...
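In both csrc files above, the wrapper signatures (gemm_bias, mlp_gemm) are unchanged; the rocBLAS-only solution_index and flag arguments simply have nowhere to go, because hipblasGemmEx accepts neither. A hypothetical sketch of that pattern, not code from the repository, with a made-up wrapper name and fp32-only types:

// wrapper_sketch.cpp -- keep the old signature, forward to hipblasGemmEx, ignore rocBLAS-only knobs.
#include <hipblas/hipblas.h>

hipblasStatus_t gemm_fp32_wrapper(hipblasHandle_t handle,
                                  hipblasOperation_t transa, hipblasOperation_t transb,
                                  int m, int n, int k,
                                  const float* alpha,
                                  const float* A, int lda,
                                  const float* B, int ldb,
                                  const float* beta,
                                  float* C, int ldc,
                                  int /*solution_index: no hipBLAS equivalent*/,
                                  int /*flag: no hipBLAS equivalent*/) {
  // hipblasGemmEx writes the result into C directly; there is no separate D output.
  return hipblasGemmEx(handle, transa, transb, m, n, k,
                       alpha,
                       A, HIPBLAS_R_32F, lda,
                       B, HIPBLAS_R_32F, ldb,
                       beta,
                       C, HIPBLAS_R_32F, ldc,
                       HIPBLAS_R_32F,
                       HIPBLAS_GEMM_DEFAULT);
}

Callers that previously tuned solution_index or flags lose that control under hipBLAS; algorithm selection is left to HIPBLAS_GEMM_DEFAULT, which is the trade-off this commit accepts.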
tests/distributed/DDP/ddp_race_condition_test.py
View file @ 4e7a2a8e
...
@@ -8,7 +8,7 @@ import os
 parser = argparse.ArgumentParser(description='allreduce hook example')
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 args = parser.parse_args()
 args.distributed = False
...
tests/distributed/amp_master_params/amp_master_params.py
View file @ 4e7a2a8e
...
@@ -8,7 +8,7 @@ from apex.parallel import DistributedDataParallel
 parser = argparse.ArgumentParser()
 # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
 # automatically by torch.distributed.launch.
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--opt_level", default="O2", type=str)
 args = parser.parse_args()
...
tests/distributed/synced_batchnorm/test_groups.py
View file @ 4e7a2a8e
...
@@ -26,7 +26,7 @@ batch_size = 32
 from apex.parallel import DistributedDataParallel as DDP
 parser = argparse.ArgumentParser()
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--fp16", action='store_true', default=False)
 parser.add_argument("--fp64", action='store_true', default=False)
 parser.add_argument("--group_size", default=0, type=int)
...
tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py
View file @ 4e7a2a8e
...
@@ -23,7 +23,7 @@ def compare(desc, inp1, inp2, error= 1e-5):
     return close
 parser = argparse.ArgumentParser()
-parser.add_argument('--local_rank', type=int, default=0)
+parser.add_argument('--local-rank', type=int, default=0)
 parser.add_argument('--apex', action='store_true')
 args = parser.parse_args()
...
tests/distributed/synced_batchnorm/two_gpu_unit_test.py
View file @ 4e7a2a8e
...
@@ -26,7 +26,7 @@ batch_size = 32
 from apex.parallel import DistributedDataParallel as DDP
 parser = argparse.ArgumentParser()
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--fp16", action='store_true', default=False)
 parser.add_argument("--fp64", action='store_true', default=False)
 args = parser.parse_args()
...