Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a3c14b5f
Commit
a3c14b5f
authored
Jun 08, 2023
by
guangzlu
Browse files
removed global shuffle apis
parent
9e16e38e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
78 deletions
+63
-78
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+63
-78
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
a3c14b5f
...
@@ -39,7 +39,6 @@ template <typename GridwiseGemm,
...
@@ -39,7 +39,6 @@ template <typename GridwiseGemm,
typename
B1GridDesc_BK0_N_BK1
,
typename
B1GridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
,
typename
LSEGridDescriptor_M
,
typename
LSEGridDescriptor_M
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
typename
ComputeBasePtrOfStridedBatch
,
typename
ComputeBasePtrOfStridedBatch
,
...
@@ -71,8 +70,6 @@ __global__ void
...
@@ -71,8 +70,6 @@ __global__ void
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
const
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
const
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
LSEGridDescriptor_M
lse_grid_desc_m
,
const
Block2CTileMap
block_2_ctile_map
,
const
Block2CTileMap
block_2_ctile_map
,
const
index_t
batch_count
,
const
index_t
batch_count
,
...
@@ -130,7 +127,6 @@ __global__ void
...
@@ -130,7 +127,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
,
block_2_ctile_map
,
c0_matrix_mask
,
c0_matrix_mask
,
...
@@ -163,7 +159,6 @@ __global__ void
...
@@ -163,7 +159,6 @@ __global__ void
b1_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5
,
lse_grid_desc_m
,
lse_grid_desc_m
,
block_2_ctile_map
,
block_2_ctile_map
,
c0_matrix_mask
,
c0_matrix_mask
,
...
@@ -653,10 +648,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -653,10 +648,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
(
z_grid_desc_m_n_
);
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
=
GridwiseGemm
::
MakeCShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
(
z_grid_desc_m_n_
);
if
(
p_lse_grid
==
nullptr
)
if
(
p_lse_grid
==
nullptr
)
{
{
is_lse_storing_
=
false
;
is_lse_storing_
=
false
;
...
@@ -706,9 +697,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -706,9 +697,6 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
;
typename
GridwiseGemm
::
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
;
// block-to-c-tile map
// block-to-c-tile map
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
@@ -765,72 +753,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -765,72 +753,69 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
,
auto
launch_kernel
=
auto
is_dropout_
,
[
&
](
auto
has_main_k_block_loop_
,
auto
is_dropout_
,
auto
is_lse_storing_
)
{
auto
is_lse_storing_
)
{
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle
<
const
auto
kernel
=
kernel_batched_multiheadattention_forward_xdl_cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
CDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
GemmAccDataType
,
GemmAccDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
typename
GridwiseGemm
::
ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
,
DeviceOp
::
LSEGridDesc_M
,
typename
GridwiseGemm
::
ZShuffleGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_M4_N4_N5
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
DeviceOp
::
LSEGridDesc_M
,
ComputeBasePtrOfStridedBatch
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
C0MatrixMask
,
ComputeBasePtrOfStridedBatch
,
has_main_k_block_loop_
,
C0MatrixMask
,
is_dropout_
,
has_main_k_block_loop_
,
is_lse_storing_
,
is_dropout_
,
Deterministic
>
;
is_lse_storing_
,
Deterministic
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
kernel
,
stream_config
,
dim3
(
grid_size
),
kernel
,
dim3
(
BlockSize
),
dim3
(
grid_size
),
0
,
dim3
(
BlockSize
),
arg
.
p_a_grid_
,
0
,
arg
.
p_b_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_z_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_z_grid_
,
arg
.
a_element_op_
,
arg
.
p_lse_grid_
,
arg
.
b_element_op_
,
arg
.
a_element_op_
,
arg
.
acc_element_op_
,
arg
.
b_element_op_
,
arg
.
b1_element_op_
,
arg
.
acc_element_op_
,
arg
.
c_element_op_
,
arg
.
b1_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
c_element_op_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
lse_grid_desc_m_
,
arg
.
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_
,
arg
.
block_2_ctile_map_
,
arg
.
z_grid_shuffle_desc_m0_n0_m1_n1_m2_n2_m3_n3_m4_n4_n5_
,
arg
.
batch_count_
,
arg
.
lse_grid_desc_m_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
),
arg
.
block_2_ctile_map_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
batch_count_
,
arg
.
c0_matrix_mask_
,
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
),
arg
.
p_dropout_in_16bits_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
p_dropout_rescale_
,
arg
.
c0_matrix_mask_
,
arg
.
seed_
,
arg
.
p_dropout_in_16bits_
,
arg
.
offset_
,
arg
.
p_dropout_rescale_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
seed_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
arg
.
offset_
,
};
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
// to concern Gemm0's loop
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment