OpenDAS / apex / Commits / db7007ae

Commit db7007ae, authored Nov 14, 2022 by flyingdown
Parent: 32ab028c

modify rocblas_gemm_ex's compute_type to rocblas_datatype_f16_r for fp16
Showing 10 changed files with 175 additions and 93 deletions.
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                    +21 -15
  apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu           +21 -15
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   +20 -13
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                 +20 -13
  apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                      +18 -12
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu             +18 -12
  apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                         +46 -7
  apex/contrib/multihead_attn/encdec_multihead_attn.py                              +1 -0
  csrc/fused_dense_cuda.cu                                                          +5 -3
  csrc/mlp_cuda.cu                                                                  +5 -3
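The pattern repeated across the attention and fused-dense kernels below is the same: every rocblas_gemm_ex / rocblas_gemm_strided_batched_ex call that already uses half inputs and outputs has its compute_type switched from rocblas_datatype_f32_r to rocblas_datatype_f16_r, and the alpha/beta scaling factors are re-declared as half, since rocBLAS reads those scalars in the compute type. As a minimal sketch of the resulting call shape (not code from this commit; the handle, pointers and sizes are placeholders, and it assumes a ROCm toolchain where hip/hip_fp16.h provides half and __float2half on the host):

    #include <rocblas.h>        // rocblas_gemm_ex and the rocblas_datatype_* enums
    #include <hip/hip_fp16.h>   // half, __float2half

    // D = alpha * A * B + beta * C with A, B, C, D and the accumulator all in
    // fp16; column-major, non-transposed m x k and k x n operands assumed.
    void gemm_fp16_sketch(rocblas_handle handle, const half* A, const half* B,
                          half* C, int m, int n, int k) {
      const half alpha = __float2half(1.0f);  // scalars live in the compute type
      const half beta  = __float2half(0.0f);
      rocblas_status st = rocblas_gemm_ex(
          handle, rocblas_operation_none, rocblas_operation_none, m, n, k,
          &alpha,
          A, rocblas_datatype_f16_r, m,
          B, rocblas_datatype_f16_r, k,
          &beta,
          C, rocblas_datatype_f16_r, m,          // C is read ...
          C, rocblas_datatype_f16_r, m,          // ... and written in place as D
          rocblas_datatype_f16_r,                // compute_type, previously f32_r
          rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);
      (void)st;  // error handling omitted in this sketch
    }

Assuming the backend honours the requested compute type, products are now also accumulated in half precision, trading the extra range of a float accumulator for speed on fp16-only paths; that trade-off is the substance of this commit.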
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu  (+21 -15)

@@ -42,9 +42,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   const int batch_stride_q = head_dim;
   const int batch_stride_kv = 2 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -110,7 +113,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       q_lin_results_ptr,
       rocblas_datatype_f16_r,
       output_lin_q_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -136,7 +139,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       k_lin_results_ptr,
       rocblas_datatype_f16_r,
       output_lin_kv_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -239,7 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<void *>(outputs.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -278,9 +281,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int batch_stride_q = head_dim;
   const int batch_stride_kv = 2 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -352,7 +358,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -378,7 +384,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -513,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_q_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -539,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_q_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -565,7 +571,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_kv_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -591,7 +597,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_kv_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu  (+21 -15)

@@ -51,9 +51,12 @@ std::vector<torch::Tensor> fwd_cuda(
   const int batch_stride_q = head_dim;
   const int batch_stride_kv = 2 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -137,7 +140,7 @@ std::vector<torch::Tensor> fwd_cuda(
       q_lin_results_ptr,
       rocblas_datatype_f16_r /*d_type*/,
       output_lin_q_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -163,7 +166,7 @@ std::vector<torch::Tensor> fwd_cuda(
       k_lin_results_ptr,
       rocblas_datatype_f16_r /*d_type*/,
       output_lin_kv_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -266,7 +269,7 @@ std::vector<torch::Tensor> fwd_cuda(
       static_cast<void *>(output_lin_results.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -330,9 +333,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int batch_stride_q = head_dim;
   const int batch_stride_kv = 2 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -416,7 +422,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -442,7 +448,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -578,7 +584,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_lin_q_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -604,7 +610,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_q_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -630,7 +636,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_kv_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -656,7 +662,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_kv_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  (+20 -13)

@@ -37,10 +37,14 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta_zero = 0.0;
-  const float beta_one = 1.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta_zero = 0.0;
+  // const float beta_one = 1.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta_zero = 0.0;
+  const half beta_one = 1.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -106,7 +110,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       q_lin_results_ptr,
       rocblas_datatype_f16_r,
       output_lin_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -203,7 +207,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<void *>(outputs.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -231,9 +235,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -301,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -327,7 +334,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu  (+20 -13)

@@ -36,10 +36,14 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta_zero = 0.0;
-  const float beta_one = 1.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta_zero = 0.0;
+  // const float beta_one = 1.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta_zero = 0.0;
+  const half beta_one = 1.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -104,7 +108,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       q_lin_results_ptr,
       rocblas_datatype_f16_r,
       output_lin_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -209,7 +213,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
       static_cast<void *>(outputs.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -237,9 +241,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -307,7 +314,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -333,7 +340,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -461,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -487,7 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu  (+18 -12)

@@ -36,9 +36,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -102,7 +105,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       q_lin_results_ptr,
       rocblas_datatype_f16_r,
       output_lin_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -205,7 +208,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<void *>(outputs.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -233,9 +236,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -303,7 +309,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -329,7 +335,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -464,7 +470,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -490,7 +496,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_grads.data_ptr()),
       rocblas_datatype_f16_r,
       embed_dim,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  (+18 -12)

@@ -40,9 +40,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // There is no reason to use more than one stream as every kernel is
   // sequentially dependent
@@ -124,7 +127,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       q_lin_results_ptr,
       rocblas_datatype_f16_r /*d_type*/,
       output_lin_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -228,7 +231,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
       static_cast<void *>(output_lin_results.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -280,9 +283,12 @@ std::vector<torch::Tensor> bwd_cuda(
   const int lead_dim = attn_batches * 3 * head_dim;
   const int batch_stride = 3 * head_dim;
   const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
-  const float alpha = 1.0;
-  const float beta = 0.0;
-  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  // const float alpha = 1.0;
+  // const float beta = 0.0;
+  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+  const half alpha = 1.0;
+  const half beta = 0.0;
+  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
   // TODO: Streams can be used in Backprop but I haven't added more than one
   // in my first attempt to create the code
@@ -361,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_lin_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -387,7 +393,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(output_weight_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -523,7 +529,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_lin_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
@@ -550,7 +556,7 @@ std::vector<torch::Tensor> bwd_cuda(
       static_cast<void *>(input_weight_grads.data_ptr()),
       rocblas_datatype_f16_r /*d_type*/,
       embed_dim,
-      rocblas_datatype_f32_r /*compute_type*/,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
       rocblas_gemm_algo_standard /*algo*/,
       0 /*solution_index*/,
       flags));
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh  (+46 -7)

@@ -42,9 +42,48 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
   }
 }
+// void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
+//     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+//     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+//   cublasOperation_t opa = convertTransToCublasOperation(transa);
+//   cublasOperation_t opb = convertTransToCublasOperation(transb);
+//   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+//   cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+//   cublasSetStream(handle, stream);
+//   float fAlpha = alpha;
+//   float fBeta = beta;
+//   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+//   TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+//       opa, opb, (int)m, (int)n, (int)k,
+//       (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+//       b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
+//       (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+//       d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
+//       (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+// }
+// void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
+//     float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+//     float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+//   auto stream = c10::cuda::getCurrentCUDAStream();
+//   if ( (transa == 't') && (transb == 'n') ) {
+//     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//   } else if ( (transa == 'n') && (transb == 'n') ) {
+//     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//   } else if ( (transa == 'n') && (transb == 't') ) {
+//     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//     else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+//   } else {
+//     AT_ASSERTM(false, "TransA and TransB are invalid");
+//   }
+// }
 void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
-    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+    half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+    half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
   cublasOperation_t opa = convertTransToCublasOperation(transa);
   cublasOperation_t opb = convertTransToCublasOperation(transb);
@@ -56,16 +95,16 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
   //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
       opa, opb, (int)m, (int)n, (int)k,
-      (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+      (void*)&alpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
       b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
-      (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+      (void*)&beta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
       d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-      (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+      (int)batchCount, rocblas_datatype_f16_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
 }
 void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
-    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+    half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+    half beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
   auto stream = c10::cuda::getCurrentCUDAStream();
   if ( (transa == 't') && (transb == 'n') ) {
     if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
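The batched variant above keeps the old fp32-accumulation code only as a comment block and drives rocblas_gemm_strided_batched_ex with all-f16 datatypes, so its callers now hand in half scaling factors directly. A standalone sketch of that call shape (again not taken from the file; the operations, strides and sizes are placeholders) could look like:

    #include <rocblas.h>
    #include <hip/hip_fp16.h>

    // batch_count independent fp16 GEMMs over matrices laid out back to back.
    void batched_fp16_gemm_sketch(rocblas_handle handle, const half* A,
                                  const half* B, half* C,
                                  int m, int n, int k, int batch_count) {
      const half alpha = __float2half(1.0f);
      const half beta  = __float2half(0.0f);
      const rocblas_stride stride_a = static_cast<rocblas_stride>(m) * k;
      const rocblas_stride stride_b = static_cast<rocblas_stride>(k) * n;
      const rocblas_stride stride_c = static_cast<rocblas_stride>(m) * n;
      rocblas_status st = rocblas_gemm_strided_batched_ex(
          handle, rocblas_operation_none, rocblas_operation_none, m, n, k,
          &alpha,
          A, rocblas_datatype_f16_r, m, stride_a,
          B, rocblas_datatype_f16_r, k, stride_b,
          &beta,
          C, rocblas_datatype_f16_r, m, stride_c,   // c
          C, rocblas_datatype_f16_r, m, stride_c,   // d, written in place
          batch_count,
          rocblas_datatype_f16_r,                   // compute_type, previously f32_r
          rocblas_gemm_algo_standard, 0 /*solution_index*/, 0 /*flags*/);
      (void)st;  // error handling omitted
    }

Note that gemm_switch_fp32accum keeps its name even though, with an f16 compute type, the accumulation it requests is no longer fp32.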
apex/contrib/multihead_attn/encdec_multihead_attn.py  (+1 -0)

@@ -151,6 +151,7 @@ class EncdecMultiheadAttn(nn.Module):
                 self.dropout,
             )
         if is_training:
+            print('default:', outputs)
             outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
         else:
             outputs = outputs + query
csrc/fused_dense_cuda.cu  (+5 -3)

@@ -164,6 +164,8 @@ cublasStatus_t gemm_bias(
     at::Half* C,
     int ldc) {
 #ifdef __HIP_PLATFORM_HCC__
+  half hAlpha = __float2half(*alpha);
+  half hBeta = __float2half(*beta);
   return rocblas_gemm_ex(
       handle,
       transa,
@@ -171,21 +173,21 @@ cublasStatus_t gemm_bias(
       m,
       n,
       k,
-      alpha,
+      /* alpha */ &hAlpha,
       A,
       rocblas_datatype_f16_r,
       lda,
       B,
       rocblas_datatype_f16_r,
       ldb,
-      beta,
+      /* beta */ &hBeta,
       C,
       rocblas_datatype_f16_r,
       ldc,
       C,
       rocblas_datatype_f16_r,
       ldc,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard,
       0,
       0);
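gemm_bias here (and mlp_gemm in the next file) still receives float* alpha and beta from the caller, which is presumably why the HIP branch now materialises hAlpha/hBeta: rocblas_gemm_ex dereferences the alpha/beta pointers in the compute type, so once that type is rocblas_datatype_f16_r they have to point at half values. A hypothetical helper illustrating just that conversion step (the struct and function names are illustrative, not from the file):

    #include <hip/hip_fp16.h>

    // Re-express host-side float scalars in the GEMM compute type (f16); the
    // addresses of the members are what would be passed as alpha/beta.
    struct HalfScalars { half alpha; half beta; };

    inline HalfScalars to_half_scalars(const float* alpha, const float* beta) {
      return { __float2half(*alpha), __float2half(*beta) };
    }

    // usage: HalfScalars s = to_half_scalars(alpha, beta);
    //        ... then pass &s.alpha and &s.beta to rocblas_gemm_ex ...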
csrc/mlp_cuda.cu  (+5 -3)

@@ -211,6 +211,8 @@ cublasStatus_t mlp_gemm(
     int ldc,
     int flag) {
 #ifdef __HIP_PLATFORM_HCC__
+  half hAlpha = __float2half(*alpha);
+  half hBeta = __float2half(*beta);
   return rocblas_gemm_ex(
       handle,
       transa,
@@ -218,21 +220,21 @@ cublasStatus_t mlp_gemm(
       m,
       n,
       k,
-      alpha,
+      /* alpha */ &hAlpha,
       A,
       rocblas_datatype_f16_r,
       lda,
       B,
       rocblas_datatype_f16_r,
       ldb,
-      beta,
+      /* beta */ &hBeta,
       C,
       rocblas_datatype_f16_r,
       ldc,
       C,
       rocblas_datatype_f16_r,
       ldc,
-      rocblas_datatype_f32_r,
+      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
       rocblas_gemm_algo_standard,
       0,
       flag);