OpenDAS / apex · Commits

Commit 227be6be
Authored Sep 19, 2023 by root; committed by flyingdown on Sep 19, 2023

    revert multihead_attn to fp32_r

Parent: 412a8ac5
Showing 7 changed files with 100 additions and 433 deletions.
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                    +44  -176
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu           +47  -180
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu    +4  -11
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                  +0  -7
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                       +5  -11
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu              +0  -6
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh                          +0  -42
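For context, the pattern reverted throughout the hunks below is the rocblas_gemm_ex compute type: the operands stay rocblas_datatype_f16_r, the compute/accumulation type goes back from rocblas_datatype_f16_r to rocblas_datatype_f32_r, and the duplicated half-precision scalar and branch code is removed. The following is only a plain-C++ analogy of why accumulation precision matters (float vs. double standing in for half vs. float accumulation; the values and loop length are made up for illustration):

#include <cstdio>

// Accumulate many small terms into a large running sum.
// In low precision the small terms are rounded away; in higher
// precision the sum stays close to the exact value. The same effect
// motivates keeping GEMM accumulation in fp32 even when the
// operands are fp16.
int main() {
  const int n = 1 << 20;      // 1,048,576 terms (arbitrary example size)
  float sum_low = 1.0e7f;     // low-precision accumulator
  double sum_high = 1.0e7;    // higher-precision accumulator
  for (int i = 0; i < n; ++i) {
    sum_low += 0.25f;         // each term is representable on its own
    sum_high += 0.25;
  }
  // Exact answer: 1.0e7 + 0.25 * n = 10,262,144
  std::printf("low  precision accumulator: %.1f\n", sum_low);   // stays at 10,000,000
  std::printf("high precision accumulator: %.1f\n", sum_high);  // 10,262,144
  return 0;
}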
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
@@ -42,13 +42,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
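As a side note, the scale kept in fp32 here is the usual attention softmax scaling 1/sqrt(head_dim); the h_scale copy exists only so the fp16 GEMM path can pass a half scalar. A tiny standalone check of the value (head_dim = 64 is an assumed example, not taken from this file):

#include <cmath>
#include <cstdio>

int main() {
  const int head_dim = 64;  // example value, not from the diff
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::printf("scale = %f\n", scale);  // prints 0.125000 for head_dim = 64
  return 0;
}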
@@ -285,9 +281,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -390,176 +383,51 @@ std::vector<torch::Tensor> bwd_cuda(
                           0 /*solution_index*/, flags));
    // Output Linear Wgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches_q,
        static_cast<const void *>(&h_alpha),
        static_cast<const void *>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(output_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(&h_beta),
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, h_beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Matmul2 Dgrad2
    gemm_switch_fp32accum(
        a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_alpha,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim,
        static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, h_beta,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        attn_batches, flags);
  } else {
    // Output Linear Dgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches_q, embed_dim,
        static_cast<const void *>(&alpha),
        static_cast<const void *>(output_weights.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(output_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(&beta),
        static_cast<void *>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<void *>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        rocblas_datatype_f32_r,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Output Linear Wgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches_q,
        static_cast<const void *>(&alpha),
        static_cast<const void *>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(output_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<const void *>(&beta),
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r, embed_dim,
        rocblas_datatype_f32_r,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Matmul2 Dgrad2
    gemm_switch_fp32accum(
        a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim,
        static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, beta,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        attn_batches, flags);
  }
  // Matmul2 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
      static_cast<const half *>(output_lin_grads.data_ptr()),
      head_dim * attn_batches, head_dim,
      static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
      k_seq_len * q_seq_len, beta,
      v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
      v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
      attn_batches, flags);
  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
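A sketch of the math behind the GEMMs labelled "MatMul2 Dgrad1" and "Matmul2 Dgrad2" above, in standard attention notation (this is an interpretation added for readability, not text from the commit): with $O = P V$, where $P$ is the dropout-masked softmax matrix and $V$ the value projection,

$$\frac{\partial L}{\partial P} = \frac{\partial L}{\partial O}\, V^{\top} \quad \text{(Dgrad1, written to matmul2\_grads)}, \qquad \frac{\partial L}{\partial V} = P^{\top}\, \frac{\partial L}{\partial O} \quad \text{(Dgrad2, written to v\_lin\_grads)}.$$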
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
@@ -51,13 +51,9 @@ std::vector<torch::Tensor> fwd_cuda(
  const int batch_stride_q = head_dim;
  const int batch_stride_kv = 2 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -337,9 +333,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
@@ -454,177 +447,51 @@ std::vector<torch::Tensor> bwd_cuda(
                           0 /*solution_index*/, flags));
    // Output Linear Wgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches_q,
        static_cast<const void *>(&h_alpha),
        static_cast<const void *>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r /*a_type*/, embed_dim,
        static_cast<const void *>(dropout_add_grads.data_ptr()),
        rocblas_datatype_f16_r /*b_type*/, embed_dim,
        static_cast<const void *>(&h_beta),
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r /*c_type*/, embed_dim,
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r /*d_type*/, embed_dim,
        /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, h_alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, h_beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Matmul2 Dgrad2
    gemm_switch_fp32accum(
        a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, h_alpha,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim,
        static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, h_beta,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        attn_batches, flags);
  } else {
    // Output Linear Dgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_N), embed_dim, batches_q, embed_dim,
        static_cast<const void *>(&alpha),
        static_cast<const void *>(output_weights.data_ptr()),
        rocblas_datatype_f16_r /*a_type*/, embed_dim,
        static_cast<const void *>(dropout_add_grads.data_ptr()),
        rocblas_datatype_f16_r /*b_type*/, embed_dim,
        static_cast<const void *>(&beta),
        static_cast<void *>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r /*c_type*/, embed_dim,
        static_cast<void *>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r /*d_type*/, embed_dim,
        rocblas_datatype_f32_r /*compute_type*/,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Output Linear Wgrad
    TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
        (rocblas_handle)handle, hipOperationToRocOperation(CUBLAS_OP_N),
        hipOperationToRocOperation(CUBLAS_OP_T), embed_dim, embed_dim, batches_q,
        static_cast<const void *>(&alpha),
        static_cast<const void *>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r /*a_type*/, embed_dim,
        static_cast<const void *>(dropout_add_grads.data_ptr()),
        rocblas_datatype_f16_r /*b_type*/, embed_dim,
        static_cast<const void *>(&beta),
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r /*c_type*/, embed_dim,
        static_cast<void *>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r /*d_type*/, embed_dim,
        rocblas_datatype_f32_r /*compute_type*/,
        rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags)));
    // MatMul2 Dgrad1
    gemm_switch_fp32accum(
        a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
        static_cast<const half *>(v_lin_results_ptr), lead_dim_kv, batch_stride_kv,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim, beta,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len,
        static_cast<half *>(matmul2_grads.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, attn_batches, flags);
    // Matmul2 Dgrad2
    gemm_switch_fp32accum(
        a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
        static_cast<const half *>(output_lin_grads.data_ptr()),
        head_dim * attn_batches, head_dim,
        static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, beta,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
        attn_batches, flags);
  }
  // Matmul2 Dgrad2
  gemm_switch_fp32accum(
      a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha,
      static_cast<const half *>(output_lin_grads.data_ptr()),
      head_dim * attn_batches, head_dim,
      static_cast<const half *>(dropout_results.data_ptr()), k_seq_len,
      k_seq_len * q_seq_len, beta,
      v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
      v_lin_grads_ptr, lead_dim_kv, batch_stride_kv,
      attn_batches, flags);
  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half, float, uint32_t>(
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
@@ -37,14 +37,10 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta_zero = 0.0;
  const half h_beta_one = 1.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  const float alpha = 1.0;
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -238,9 +234,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
@@ -40,10 +40,6 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
  const float beta_zero = 0.0;
  const float beta_one = 1.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta_zero = 0.0;
  const half h_beta_one = 1.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -244,9 +240,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
@@ -36,12 +36,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const int lead_dim = attn_batches * 3 * head_dim;
  const int batch_stride = 3 * head_dim;
  const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
  // const float alpha = 1.0;
  // const float beta = 0.0;
  // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half alpha = 1.0;
  const half beta = 0.0;
  const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -105,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
      q_lin_results_ptr, rocblas_datatype_f16_r, output_lin_dim,
      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
@@ -208,7 +205,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
      static_cast<void *>(outputs.data_ptr()), rocblas_datatype_f16_r, embed_dim,
      /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
      rocblas_datatype_f32_r,
      rocblas_gemm_algo_standard /*algo*/, 0 /*solution_index*/, flags));
@@ -239,9 +236,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
@@ -43,9 +43,6 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // There is no reason to use more than one stream as every kernel is
  // sequentially dependent
@@ -286,9 +283,6 @@ std::vector<torch::Tensor> bwd_cuda(
  const float alpha = 1.0;
  const float beta = 0.0;
  const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
  const half h_alpha = 1.0;
  const half h_beta = 0.0;
  const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
  // TODO: Streams can be used in Backprop but I haven't added more than one
  // in my first attempt to create the code
apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh
@@ -10,7 +10,6 @@
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "utils.h"
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
@@ -29,8 +28,6 @@ int32_t solution_index = 0;
rocblas_int flags = 0;
*/
static bool use_fp16 = parseEnvVarFlag("APEX_APEX_ROCBLAS_GEMM_ALLOW_HALF");
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
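The use_fp16 switch above is driven by the APEX_APEX_ROCBLAS_GEMM_ALLOW_HALF environment variable through parseEnvVarFlag from utils.h. That helper's body is not part of this diff; the following is only a hypothetical sketch of what such a flag parser typically looks like (the name parse_env_flag_sketch and the accepted spellings are assumptions):

#include <cstdlib>
#include <cstring>

// Hypothetical stand-in for parseEnvVarFlag (the real one lives in utils.h
// and is not shown in this commit). Treats an unset variable as "off".
static bool parse_env_flag_sketch(const char *name) {
  const char *value = std::getenv(name);
  if (value == nullptr) return false;
  return std::strcmp(value, "1") == 0 || std::strcmp(value, "true") == 0 ||
         std::strcmp(value, "TRUE") == 0;
}

// Example use, mirroring the line above:
// static bool use_fp16 = parse_env_flag_sketch("APEX_APEX_ROCBLAS_GEMM_ALLOW_HALF");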
@@ -84,45 +81,6 @@ void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
  }
}

void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
                               half alpha, const half *a, long lda, long strideA,
                               const half *b, long ldb, long strideB, half beta,
                               half *c, long ldc, long strideC, half *d, long ldd,
                               long strideD, long batchCount,
                               rocblas_gemm_algo algo, rocblas_int flags) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  cublasSetStream(handle, stream);
  float fAlpha = alpha;
  float fBeta = beta;
  //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(
      handle, opa, opb, (int)m, (int)n, (int)k, (void *)&alpha,
      a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
      b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
      (void *)&beta,
      c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
      d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
      (int)batchCount, rocblas_datatype_f16_r /*compute_type*/,
      algo, 0 /*solution_index*/, flags));
}

void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                           half alpha, const half *a, long lda, long strideA,
                           const half *b, long ldb, long strideB, half beta,
                           half *c, long ldc, long strideC, half *d, long ldd,
                           long strideD, long batchCount, rocblas_int flags) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  if ((transa == 't') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    }
  } else if ((transa == 'n') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    }
  } else if ((transa == 'n') && (transb == 't')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    } else {
      RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA,
                                b, ldb, strideB, beta, c, ldc, strideC, d, ldd,
                                strideD, batchCount, rocblas_gemm_algo_standard,
                                flags);
    }
  } else {
    AT_ASSERTM(false, "TransA and TransB are invalid");
  }
}

void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
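One detail worth noting in gemm_switch_fp32accum above: the test !(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7) checks whether the leading dimensions are multiples of 8, a common alignment heuristic for fp16 GEMMs, although in the version shown both branches end up issuing the identical RocblasStridedBatchedGemm call. A small self-contained check of the bit trick:

#include <cassert>

// (ld & 0x7) == 0 holds exactly when ld is divisible by 8,
// because 0x7 masks the three lowest bits of ld.
static bool is_multiple_of_8(long ld) { return (ld & 0x7) == 0; }

int main() {
  for (long ld = 0; ld < 100; ++ld) {
    assert(is_multiple_of_8(ld) == (ld % 8 == 0));
  }
  return 0;
}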