OpenDAS / apex / Commits / 4e7a2a8e

Commit 4e7a2a8e, authored Nov 28, 2023 by flyingdown

fix up for torch2.1

parent 2a4864d5

Showing 15 changed files with 382 additions and 574 deletions (+382 −574)
Changed files:

  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   (+65 −101)
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu          (+65 −101)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (+44 −68)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                (+45 −69)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     (+44 −68)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu            (+44 −68)
  apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                        (+27 −28)
  apex/contrib/test/run_rocm_extensions.py                                         (+1 −0)
  csrc/fused_dense_cuda.cu                                                         (+21 −33)
  csrc/mlp_cuda.cu                                                                 (+21 −33)
  tests/distributed/DDP/ddp_race_condition_test.py                                 (+1 −1)
  tests/distributed/amp_master_params/amp_master_params.py                         (+1 −1)
  tests/distributed/synced_batchnorm/test_groups.py                                (+1 −1)
  tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py          (+1 −1)
  tests/distributed/synced_batchnorm/two_gpu_unit_test.py                          (+1 −1)
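Every GEMM hunk in this commit applies the same substitution: the direct rocBLAS entry points (rocblas_gemm_ex / rocblas_gemm_strided_batched_ex) are replaced with their hipBLAS equivalents (hipblasGemmEx / hipblasGemmStridedBatchedEx), with HIPBLAS_R_16F / HIPBLAS_R_32F datatype enums and HIPBLAS_GEMM_DEFAULT as the algorithm. The sketch below shows the new call shape only for orientation; it is not apex code. It assumes the pre-2.x hipBLAS interface (hipblasDatatype_t enums such as HIPBLAS_R_16F), and the function name, dimensions and pointers are placeholders.

```cpp
// Sketch only: an FP16 GEMM with FP32 accumulation in the style used by this
// commit.  Assumes hipBLAS < 2.0 naming; arguments are illustrative.
#include <hipblas/hipblas.h>   // <hipblas.h> on older ROCm installs
#include <hip/hip_fp16.h>

hipblasStatus_t half_gemm_fp32_accum(hipblasHandle_t handle,
                                     int m, int n, int k,
                                     const __half* A, int lda,
                                     const __half* B, int ldb,
                                     __half* C, int ldc) {
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  // C = alpha * A^T * B + beta * C, FP16 storage, FP32 compute type.
  return hipblasGemmEx(handle,
                       HIPBLAS_OP_T, HIPBLAS_OP_N,
                       m, n, k,
                       static_cast<const void*>(&alpha),
                       static_cast<const void*>(A), HIPBLAS_R_16F, lda,
                       static_cast<const void*>(B), HIPBLAS_R_16F, ldb,
                       static_cast<const void*>(&beta),
                       static_cast<void*>(C), HIPBLAS_R_16F, ldc,
                       HIPBLAS_R_32F,           // compute type
                       HIPBLAS_GEMM_DEFAULT);   // algo
}
```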
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  View file @ 4e7a2a8e

@@ -85,12 +85,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
                        CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        output_lin_q_dim,

@@ -98,25 +98,21 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
                        embed_dim,
                        static_cast<const void*>(&alpha),
                        static_cast<const void*>(input_weights_q.data_ptr()),
-                       rocblas_datatype_f16_r,
+                       HIPBLAS_R_16F,
                        embed_dim,
                        static_cast<const void*>(inputs_q.data_ptr()),
-                       rocblas_datatype_f16_r,
+                       HIPBLAS_R_16F,
                        embed_dim,
                        static_cast<const void*>(&beta),
                        q_lin_results_ptr,
-                       rocblas_datatype_f16_r,
+                       HIPBLAS_R_16F,
                        output_lin_q_dim,
-                       q_lin_results_ptr,
-                       rocblas_datatype_f16_r,
-                       output_lin_q_dim,
-                       rocblas_datatype_f32_r,
-                       rocblas_gemm_algo_standard /*algo*/,
-                       0 /*solution_index*/,
-                       flags));
+                       HIPBLAS_R_32F,
+                       HIPBLAS_GEMM_DEFAULT /*algo*/));
   // Input Linear KV Fwd
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  TORCH_CUDABLAS_CHECK(hipblasGemmEx(handle,
                        CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        output_lin_kv_dim,

The remaining hunks in this file apply the identical rewrite (rocblas_gemm_ex → hipblasGemmEx; rocblas_datatype_f16_r → HIPBLAS_R_16F; rocblas_datatype_f32_r → HIPBLAS_R_32F; rocblas_gemm_algo_standard → HIPBLAS_GEMM_DEFAULT; the separate D-matrix, solution_index and flags arguments removed) to the other GEMM calls:

@@ -124,22 +120,18 @@   Input Linear KV Fwd (input_weights_kv, inputs_kv, k_lin_results_ptr)
@@ -219,7 +211,7 @@     Output Linear (fwd)
@@ -227,22 +219,18 @@   Output Linear body (output_weights, matmul2_results, outputs)
@@ -318,7 +306,7 @@     bwd_cuda: rocblas_int flags = 0 → int flags = 0
@@ -332,7 +320,7 @@     Output Linear Dgrad
@@ -340,25 +328,21 @@   Output Linear Dgrad body (output_weights, output_grads, output_lin_grads)
@@ -366,22 +350,18 @@   Output Linear Wgrad (matmul2_results, output_grads, output_weight_grads)
@@ -493,7 +473,7 @@     Input Linear Q Dgrad
@@ -501,25 +481,21 @@   Input Linear Q Dgrad body (input_weights_q, q_lin_grads_ptr, input_q_grads)
@@ -527,25 +503,21 @@   Input Linear Q Wgrad (inputs_q, q_lin_grads_ptr, input_weight_q_grads)
@@ -553,25 +525,21 @@   Input Linear KV Dgrad (input_weights_kv, k_lin_grads_ptr, input_kv_grads)
@@ -579,22 +547,18 @@   Input Linear KV Wgrad (inputs_kv, k_lin_grads_ptr, input_weight_kv_grads)
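Why three trailing arguments disappear from every call: rocblas_gemm_ex writes its result to a separately specified D matrix and takes an explicit solution_index and flags, whereas hipblasGemmEx updates C in place and selects a kernel only through the algo enum. The commented parameter lists below are an approximation for orientation; exact prototypes vary across ROCm releases, so check the installed headers.

```cpp
// Approximate parameter lists (assumption; see rocblas.h / hipblas.h for the
// authoritative prototypes).
//
// rocblas_status rocblas_gemm_ex(handle, transA, transB, m, n, k,
//     alpha, A, a_type, lda,
//            B, b_type, ldb,
//     beta,  C, c_type, ldc,
//            D, d_type, ldd,          // separate output matrix
//     compute_type, algo, solution_index, flags);
//
// hipblasStatus_t hipblasGemmEx(handle, transA, transB, m, n, k,
//     alpha, A, aType, lda,
//            B, bType, ldb,
//     beta,  C, cType, ldc,           // C is updated in place
//     computeType, algo);             // no solution_index / flags
```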
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  View file @ 4e7a2a8e

Same rewrite, applied to the layer-norm + residual-add variant. Representative hunk:

@@ -101,7 +101,7 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  rocblas_int flags = 0;
+  int flags = 0;
   //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Layer Norm

The GEMM hunks keep their /*a_type*/ … /*compute_type*/ annotations and swap the enums in place, e.g. rocblas_datatype_f16_r /*a_type*/ → HIPBLAS_R_16F /*a_type*/ and rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags → HIPBLAS_GEMM_DEFAULT /*algo*/:

@@ -116,7 +116,7 @@     Input Linear Q Fwd (rocblas_gemm_ex → hipblasGemmEx)
@@ -124,26 +124,22 @@   Input Linear Q Fwd body (input_weights_q, lyr_nrm_results, q_lin_results_ptr)
@@ -151,22 +147,18 @@   Input Linear KV Fwd (input_weights_kv, inputs_kv, k_lin_results_ptr)
@@ -246,7 +238,7 @@     Output Linear (fwd)
@@ -254,22 +246,18 @@   Output Linear body (output_weights, matmul2_results, output_lin_results)
@@ -374,7 +362,7 @@     bwd_cuda: rocblas_int flags → int flags
@@ -396,7 +384,7 @@     Output Linear Dgrad
@@ -404,25 +392,21 @@   Output Linear Dgrad body (output_weights, dropout_add_grads, output_lin_grads)
@@ -430,22 +414,18 @@   Output Linear Wgrad (matmul2_results, dropout_add_grads, output_weight_grads)
@@ -557,7 +537,7 @@     Input Linear Q Dgrad
@@ -565,26 +545,22 @@   Input Linear Q Dgrad body (input_weights_q, q_lin_grads_ptr, input_lin_q_grads)
@@ -592,25 +568,21 @@   Input Linear Q Wgrad (inputs_q, q_lin_grads_ptr, input_weight_q_grads)
@@ -618,25 +590,21 @@   Input Linear KV Dgrad (input_weights_kv, k_lin_grads_ptr, input_kv_grads)
@@ -644,22 +612,18 @@   Input Linear KV Wgrad (inputs_kv, k_lin_grads_ptr, input_weight_kv_grads)

The trailing context (// Fused Layer Norm Bwd with Residual Add, HostLayerNormGradient<half, float>(…) is unchanged.
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  View file @ 4e7a2a8e

Same rewrite in the self-attention (bias + additive mask) kernels; the bias handling (input_lin_results.copy_(input_biases), outputs.copy_(output_biases), the beta_one scalar, and the output_bias_grads / input_bias_grads view-and-sum reductions) is untouched:

@@ -80,13 +80,13 @@     fwd_cuda: rocblas_int flags → int flags; Input Linear Fwd → hipblasGemmEx
@@ -94,22 +94,18 @@     Input Linear Fwd body (input_weights, inputs, q_lin_results_ptr, beta_one)
@@ -183,7 +179,7 @@ and @@ -191,22 +187,18 @@   Output Linear (output_weights, matmul2_results, outputs, beta_one)
@@ -267,7 +259,7 @@     bwd_cuda: rocblas_int flags → int flags
@@ -281,7 +273,7 @@ and @@ -289,25 +281,21 @@   Output Linear Dgrad (output_weights, output_grads, output_lin_grads)
@@ -315,22 +303,18 @@    Output Linear Wgrad (matmul2_results, output_grads, output_weight_grads)
@@ -441,7 +425,7 @@ and @@ -449,25 +433,21 @@   Input Linear Dgrad (input_weights, input_lin_output_grads, input_grads)
@@ -475,22 +455,18 @@    Input Linear Wgrad (inputs, q_lin_grads_ptr, input_weight_grads)
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  View file @ 4e7a2a8e

Same rewrite as in the additive-mask variant above (rocblas_gemm_ex → hipblasGemmEx, rocblas_int → int, HIPBLAS_R_16F / HIPBLAS_R_32F, HIPBLAS_GEMM_DEFAULT, D / solution_index / flags arguments dropped):

@@ -78,13 +78,13 @@     fwd_cuda: flags declaration; Input Linear Fwd
@@ -92,22 +92,18 @@     Input Linear Fwd body (input_weights, inputs, q_lin_results_ptr, beta_one)
@@ -189,7 +185,7 @@ and @@ -197,22 +193,18 @@   Output Linear (output_weights, matmul2_results, outputs, beta_one)
@@ -273,7 +265,7 @@     bwd_cuda: flags declaration
@@ -287,7 +279,7 @@ and @@ -295,25 +287,21 @@   Output Linear Dgrad (output_weights, output_grads, output_lin_grads)
@@ -321,22 +309,18 @@    Output Linear Wgrad (matmul2_results, output_grads, output_weight_grads)
@@ -441,7 +425,7 @@ and @@ -449,25 +433,21 @@   Input Linear Dgrad (input_weights, input_lin_output_grads, input_grads)
@@ -475,22 +455,18 @@    Input Linear Wgrad (inputs, q_lin_grads_ptr, input_weight_grads)
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  View file @ 4e7a2a8e

Same rewrite in the plain self-attention kernels (no bias, no additive mask):

@@ -77,12 +77,12 @@     fwd_cuda: rocblas_int flags → int flags; Input Linear Fwd → hipblasGemmEx
@@ -90,22 +90,18 @@     Input Linear Fwd body (input_weights, inputs, q_lin_results_ptr)
@@ -185,7 +181,7 @@ and @@ -193,22 +189,18 @@   Output Linear (output_weights, matmul2_results, outputs)
@@ -269,7 +261,7 @@     bwd_cuda: rocblas_int flags → int flags
@@ -283,7 +275,7 @@ and @@ -291,25 +283,21 @@   Output Linear Dgrad (output_weights, output_grads, output_lin_grads)
@@ -317,22 +305,18 @@    Output Linear Wgrad (matmul2_results, output_grads, output_weight_grads)
@@ -444,7 +428,7 @@ and @@ -452,25 +436,21 @@   Input Linear Dgrad (input_weights, q_lin_grads_ptr, input_grads)
@@ -478,22 +458,18 @@    Input Linear Wgrad (inputs, q_lin_grads_ptr, input_weight_grads)
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  View file @ 4e7a2a8e

Same rewrite in the self-attention + layer-norm + residual-add kernels (the GEMM hunks keep their /*a_type*/ … /*compute_type*/ annotations):

@@ -88,7 +88,7 @@       fwd_cuda: rocblas_int flags → int flags
@@ -103,7 +103,7 @@ and @@ -111,23 +111,19 @@   Input Linear Fwd (input_weights, lyr_nrm_results, q_lin_results_ptr)
@@ -208,7 +204,7 @@ and @@ -216,22 +212,18 @@   Output Linear (output_weights, matmul2_results, output_lin_results)
@@ -320,7 +312,7 @@      bwd_cuda: rocblas_int flags → int flags
@@ -341,7 +333,7 @@ and @@ -349,25 +341,21 @@   Output Linear Dgrad (output_weights, dropout_add_grads, output_lin_grads)
@@ -375,22 +363,18 @@     Output Linear Wgrad (matmul2_results, dropout_add_grads, output_weight_grads)
@@ -502,7 +486,7 @@ and @@ -510,26 +494,22 @@   Input Linear Dgrad (input_weights, q_lin_grads_ptr, input_lin_grads)
@@ -538,22 +518,18 @@     Input Linear Wgrad (lyr_nrm_results, q_lin_grads_ptr, input_weight_grads)

The trailing context (// Fused Layer Norm Bwd with Residual Add, HostLayerNormGradient<half, float>(…) is unchanged.
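One detail worth noting across all of the gemm_ex hunks above: the old rocBLAS calls always passed the same buffer and leading dimension for the C and D matrices, so they were already computing in place, and dropping the D arguments for hipblasGemmEx does not change the result. In every case the operation is the usual mixed-precision GEMM (stated here only for orientation):

$$C \leftarrow \alpha \,\mathrm{op}(A)\,\mathrm{op}(B) + \beta\, C$$

with A, B, C stored in FP16 (HIPBLAS_R_16F) and α, β and the accumulation carried in FP32 (compute type HIPBLAS_R_32F).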
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh  View file @ 4e7a2a8e

@@ -17,15 +17,15 @@
 // symbol to be automatically resolved by PyTorch libs
 /*
-  rocblas_datatype a_type = rocblas_datatype_f16_r; // OK
-  rocblas_datatype b_type = rocblas_datatype_f16_r; // OK
-  rocblas_datatype c_type = rocblas_datatype_f16_r; // OK
-  rocblas_datatype d_type = rocblas_datatype_f16_r;
-  rocblas_datatype compute_type = rocblas_datatype_f32_r;
+  rocblas_datatype a_type = HIPBLAS_R_16F; // OK
+  rocblas_datatype b_type = HIPBLAS_R_16F; // OK
+  rocblas_datatype c_type = HIPBLAS_R_16F; // OK
+  rocblas_datatype d_type = HIPBLAS_R_16F;
+  rocblas_datatype compute_type = HIPBLAS_R_32F;
-  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
+  rocblas_gemm_algo algo = HIPBLAS_GEMM_DEFAULT;
   int32_t solution_index = 0;
-  rocblas_int flags = 0;
+  int flags = 0;
 */
 namespace {

@@ -44,38 +44,37 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
                                float alpha, const half* a, long lda, long strideA,
                                const half* b, long ldb, long strideB,
                                float beta, half* c, long ldc, long strideC,
                                half* d, long ldd, long strideD, long batchCount,
-                               rocblas_gemm_algo algo, rocblas_int flags) {
-  cublasOperation_t opa = convertTransToCublasOperation(transa);
-  cublasOperation_t opb = convertTransToCublasOperation(transb);
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-  cublasSetStream(handle, stream);
+                               hipblasGemmAlgo_t algo, int flags) {
+  hipblasOperation_t opa = convertTransToCublasOperation(transa);
+  hipblasOperation_t opb = convertTransToCublasOperation(transb);
+  hipblasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
+  hipblasSetStream(handle, stream);
   float fAlpha = alpha;
   float fBeta = beta;
-  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, opa, opb,
-      (int)m, (int)n, (int)k,
-      (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
-      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
-      (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
-      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/,
-      algo, 0 /*solution_index*/, flags));
+  //THCublasCheck(hipblasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  TORCH_CUDABLAS_CHECK(hipblasGemmStridedBatchedEx(handle, opa, opb,
+      (int)m, (int)n, (int)k,
+      (void*)&fAlpha, (const void*)a, HIPBLAS_R_16F /*a_type*/, (int)lda, strideA,
+      (const void*)b, HIPBLAS_R_16F /*b_type*/, (int)ldb, strideB,
+      (void*)&fBeta, (void*)c, HIPBLAS_R_16F /*c_type*/, (int)ldc, strideC,
+      (int)batchCount, HIPBLAS_R_32F /*compute_type*/,
+      algo));
 }

 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                            float alpha, const half* a, long lda, long strideA,
                            const half* b, long ldb, long strideB,
                            float beta, half* c, long ldc, long strideC,
                            half* d, long ldd, long strideD, long batchCount,
-                           rocblas_int flags) {
-  auto stream = c10::cuda::getCurrentCUDAStream();
+                           int flags) {
+  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA();

In the (transa, transb) dispatch branches that follow — ('t','n'), ('n','n') and ('n','t'), each split on the !(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7) alignment test — every RocblasStridedBatchedGemm(...) call site now passes HIPBLAS_GEMM_DEFAULT in place of rocblas_gemm_algo_standard; the final AT_ASSERTM(false, "TransA and TransB are invalid") branch is unchanged.
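For reference, the dispatch in gemm_switch_fp32accum keys off leading-dimension alignment: (ld & 0x7) == 0 is true exactly when ld is a multiple of 8 half elements (16 bytes), presumably the alignment originally used to pick a Tensor Core friendly path; after this commit both branches request HIPBLAS_GEMM_DEFAULT anyway. A minimal, standalone illustration of that bitmask test (plain host C++, not apex code):

```cpp
// Sketch: the same alignment predicate used by gemm_switch_fp32accum.
#include <cstdio>
#include <initializer_list>

static bool aligned_to_8_elems(long ld) { return (ld & 0x7) == 0; }

int main() {
  for (long ld : {62L, 64L, 65L, 128L}) {
    std::printf("ld=%ld aligned to 8 elements: %s\n",
                ld, aligned_to_8_elems(ld) ? "yes" : "no");
  }
  return 0;
}
```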
apex/contrib/test/run_rocm_extensions.py  View file @ 4e7a2a8e

One line is added in the extension-test configuration block:

@@ -4,6 +4,7 @@ import sys
 test_dirs = ["groupbn", "fused_dense", "layer_norm", "multihead_attn", "transducer", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 ROCM_BLACKLIST = ["groupbn", "layer_norm"]
csrc/fused_dense_cuda.cu  View file @ 4e7a2a8e

@@ -31,7 +31,7 @@ cublasStatus_t gemm_bias(
     double* C,
     int ldc) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,

@@ -40,22 +40,18 @@ cublasStatus_t gemm_bias(
       k,
       alpha,
       A,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       lda,
       B,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldc,
-      C,
-      rocblas_datatype_f64_r,
-      ldc,
-      rocblas_datatype_f64_r,
-      rocblas_gemm_algo_standard,
-      0,
-      0);
+      HIPBLAS_R_64F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,

The float and at::Half overloads of gemm_bias get the same treatment:

@@ -97,7 +93,7 @@ and @@ -106,22 +102,18 @@    float overload: rocblas_datatype_f32_r → HIPBLAS_R_32F throughout, HIPBLAS_GEMM_DEFAULT
@@ -164,7 +156,7 @@ and @@ -173,22 +165,18 @@   at::Half overload: rocblas_datatype_f16_r → HIPBLAS_R_16F storage with HIPBLAS_R_32F compute, HIPBLAS_GEMM_DEFAULT
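Unlike the multihead_attn kernels, gemm_bias returns the BLAS status to its caller rather than asserting internally, so call sites decide how to surface failures. A hedged sketch of one checking pattern is below; the checked_gemm helper is hypothetical (not part of apex), and TORCH_CHECK is the standard PyTorch-extension assertion macro. Under hipification, cublasStatus_t maps to hipblasStatus_t and CUBLAS_STATUS_SUCCESS / HIPBLAS_STATUS_SUCCESS are both zero.

```cpp
// Sketch only: wrap a gemm_bias-style call and raise a readable error on failure.
#include <torch/extension.h>
#include <utility>

template <typename GemmFn, typename... Args>
void checked_gemm(GemmFn&& gemm, Args&&... args) {
  auto status = gemm(std::forward<Args>(args)...);
  // 0 == CUBLAS_STATUS_SUCCESS == HIPBLAS_STATUS_SUCCESS
  TORCH_CHECK(status == 0, "GEMM failed with status ", static_cast<int>(status));
}
```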
csrc/mlp_cuda.cu  View file @ 4e7a2a8e

@@ -78,7 +78,7 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
-  return rocblas_gemm_ex(
+  return hipblasGemmEx(
       handle,
       transa,
       transb,

@@ -87,22 +87,18 @@ cublasStatus_t mlp_gemm(
       k,
       alpha,
       A,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       lda,
       B,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldb,
       beta,
       C,
-      rocblas_datatype_f64_r,
+      HIPBLAS_R_64F,
       ldc,
-      C,
-      rocblas_datatype_f64_r,
-      ldc,
-      rocblas_datatype_f64_r,
-      rocblas_gemm_algo_standard,
-      0,
-      flag);
+      HIPBLAS_R_64F,
+      HIPBLAS_GEMM_DEFAULT);
 #else
   return cublasGemmEx(
       handle,

The float and at::Half overloads of mlp_gemm change identically (the caller-supplied flag argument is likewise no longer forwarded in the HIP path):

@@ -145,7 +141,7 @@ and @@ -154,22 +150,18 @@   float overload: rocblas_datatype_f32_r → HIPBLAS_R_32F throughout, HIPBLAS_GEMM_DEFAULT
@@ -213,7 +205,7 @@ and @@ -222,22 +214,18 @@   at::Half overload: rocblas_datatype_f16_r → HIPBLAS_R_16F storage with HIPBLAS_R_32F compute, HIPBLAS_GEMM_DEFAULT
tests/distributed/DDP/ddp_race_condition_test.py  View file @ 4e7a2a8e

@@ -8,7 +8,7 @@ import os
 parser = argparse.ArgumentParser(description='allreduce hook example')
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 args = parser.parse_args()
 args.distributed = False
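This rename, and the identical ones in the remaining test scripts below, tracks the launcher change in PyTorch 2.x: torch.distributed.launch / torchrun now inject the hyphenated --local-rank argument (the old --local_rank spelling is deprecated), so the scripts must register the new spelling for argparse to accept the injected flag.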
tests/distributed/amp_master_params/amp_master_params.py  View file @ 4e7a2a8e

@@ -8,7 +8,7 @@ from apex.parallel import DistributedDataParallel
 parser = argparse.ArgumentParser()
 # FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
 # automatically by torch.distributed.launch.
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--opt_level", default="O2", type=str)
 args = parser.parse_args()
tests/distributed/synced_batchnorm/test_groups.py  View file @ 4e7a2a8e

@@ -26,7 +26,7 @@ batch_size = 32
 from apex.parallel import DistributedDataParallel as DDP
 parser = argparse.ArgumentParser()
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--fp16", action='store_true', default=False)
 parser.add_argument("--fp64", action='store_true', default=False)
 parser.add_argument("--group_size", default=0, type=int)
tests/distributed/synced_batchnorm/two_gpu_test_different_batch_size.py  View file @ 4e7a2a8e

@@ -23,7 +23,7 @@ def compare(desc, inp1, inp2, error= 1e-5):
     return close
 parser = argparse.ArgumentParser()
-parser.add_argument('--local_rank', type=int, default=0)
+parser.add_argument('--local-rank', type=int, default=0)
 parser.add_argument('--apex', action='store_true')
 args = parser.parse_args()
tests/distributed/synced_batchnorm/two_gpu_unit_test.py  View file @ 4e7a2a8e

@@ -26,7 +26,7 @@ batch_size = 32
 from apex.parallel import DistributedDataParallel as DDP
 parser = argparse.ArgumentParser()
-parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument("--local-rank", default=0, type=int)
 parser.add_argument("--fp16", action='store_true', default=False)
 parser.add_argument("--fp64", action='store_true', default=False)
 args = parser.parse_args()