gaoqiong/composable_kernel

Commit 7af1b43a, authored May 19, 2023 by ltqin
Parent: b281dac5

    one block dv pass
Showing 4 changed files with 115 additions and 115 deletions.
Files changed:
  example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt (+1 -1)
  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_V2R2.cpp (+11 -10)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp (+102 -103)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+1 -1)
example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt

@@ -12,7 +12,7 @@ add_example_executable(example_batched_multihead_attention_backward batched_mult
 add_example_executable(example_grouped_multihead_attention_train grouped_multihead_attention_train.cpp)
 add_example_executable(example_batched_multihead_attention_train batched_multihead_attention_train.cpp)
-add_example_executable(example_batched_multihead_attention_backward_V1R2 batched_multihead_attention_backward_V2R2.cpp)
+add_example_executable(example_batched_multihead_attention_backward_V2R2 batched_multihead_attention_backward_V2R2.cpp)
 add_custom_target(example_gemm_scale_softmax_gemm)
 add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_V2R2.cpp

@@ -350,9 +350,9 @@ using DeviceGemmInstance =
         2,           // B1K1
         32,          // MPerXDL
         32,          // NPerXDL
-        4,           // MXdlPerWave
-        1,           // NXdlPerWave
-        1,           // Gemm1NXdlPerWave
+        1,           // MXdlPerWave
+        4,           // NXdlPerWave
+        4,           // Gemm1NXdlPerWave
         2,           // Gemm2NXdlPerWave
         S<4, 64, 1>, // ABlockTransfer
         S<1, 0, 2>,
@@ -375,8 +375,8 @@ using DeviceGemmInstance =
         4,
         2,
         false,
-        4,              // CShuffleMXdlPerWavePerShuffle
-        1,              // CShuffleNXdlPerWavePerShuffle
+        1,              // CShuffleMXdlPerWavePerShuffle
+        4,              // CShuffleNXdlPerWavePerShuffle
         S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
         CShuffleBlockTransferScalarPerVector_NPerBlock, // CShuffleBlockTransferScalarPerVector_NPerBlock
         MaskingSpec,    // MaskingSpecialization
@@ -501,17 +501,17 @@ int run(int argc, char* argv[])
     // y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o
     // y_g0_g1_m_o = reshape(y_g_m_o, [G0, G1, M, O])
     // y_g0_m_g1_o = permute(y_g0_g1_m_o, [0, 2, 1, 3])
-    ck::index_t M  = 512;
-    ck::index_t N  = 512;
+    ck::index_t M  = 128;
+    ck::index_t N  = 128;
     ck::index_t K  = DIM;
     ck::index_t O  = DIM;
-    ck::index_t G0 = 4;
-    ck::index_t G1 = 6;
+    ck::index_t G0 = 1;
+    ck::index_t G1 = 1;

     bool input_permute  = false;
     bool output_permute = false;

-    float p_drop = 0.2;
+    float p_drop = 0;
     const unsigned long long seed   = 1;
     const unsigned long long offset = 0;
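Note: the hunk above shrinks the self-test problem (M and N drop from 512 to 128, G0*G1 from 24 to 1, dropout disabled), which reads like a single-tile debugging configuration for the commit's one-block dV path. For orientation, the formula quoted in the comments, y_g_m_o = Softmax(alpha * Q_g_m_k * K_g_k_n) * V_g_n_o, can be spelled out as a plain CPU reference; this sketch is written from the comment alone (names are illustrative, and the example's real verification path uses the library's host-tensor utilities instead):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Per-batch reference: Y = Softmax(alpha * Q * K^T) * V, row-major buffers.
    void attention_forward_ref(const std::vector<float>& Q,  // M x K
                               const std::vector<float>& Km, // N x K
                               const std::vector<float>& V,  // N x O
                               std::vector<float>& Y,        // M x O
                               int M, int K, int N, int O, float alpha)
    {
        std::vector<float> S(N); // one softmax row at a time
        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n) // S = alpha * Q * K^T
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += Q[m * K + k] * Km[n * K + k];
                S[n] = alpha * acc;
            }
            // numerically stable softmax over the row
            const float mx = *std::max_element(S.begin(), S.end());
            float sum = 0.f;
            for(float& s : S) { s = std::exp(s - mx); sum += s; }
            for(float& s : S) { s /= sum; }
            for(int o = 0; o < O; ++o) // Y = P * V
            {
                float acc = 0.f;
                for(int n = 0; n < N; ++n)
                    acc += S[n] * V[n * O + o];
                Y[m * O + o] = acc;
            }
        }
    }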
@@ -1040,6 +1040,7 @@ int run(int argc, char* argv[])
                                           "error",
                                           1e-2,
                                           1e-2);
+            //std::cout << vgrad_gs_os_ns_device_result << std::endl;
         }
     }
     return pass ? ((void)(std::cout << "pass\n"), 0) : ((void)(std::cout << "fail\n"), 1);
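Note: the return statement kept as context above is the comma-operator idiom: each branch of the conditional prints its verdict, the (void) cast discards the stream, and the trailing constant supplies the exit code. A minimal standalone equivalent:

    #include <iostream>

    int finish(bool pass)
    {
        // prints "pass"/"fail", then yields 0/1 via the comma operator
        return pass ? ((void)(std::cout << "pass\n"), 0)
                    : ((void)(std::cout << "fail\n"), 1);
    }

    int main() { return finish(true); }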
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp

@@ -903,8 +903,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             Sequence<0, 1, 2, 3, 4, 5, 6, 7>,     // AccessOrder
             7,                                    // VectorDim
             2,                                    // ScalarPerVector
-            InMemoryDataOperationEnum::AtomicAdd, // GlobalMemoryDataOperation
+            InMemoryDataOperationEnum::Set,       // GlobalMemoryDataOperation
             1,                                    // DstScalarStrideInVector
             true>;
     };
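Note: this one-line change carries most of the commit's meaning. With AtomicAdd, several workgroups could each contribute a partial dV tile directly into global memory; with Set, a single workgroup owns the whole reduction and emits one plain store. A CPU-shaped sketch of the two strategies, with stand-in names for the kernel's v_slash_k_grad_thread_buf and vgrad_grid_buf:

    #include <cstddef>
    #include <vector>

    // Before: every M-tile pass adds its partial dV into global memory
    // (models InMemoryDataOperationEnum::AtomicAdd).
    void dv_atomic_add_style(std::vector<float>& dv_global,
                             const std::vector<std::vector<float>>& partial_tiles)
    {
        for(const auto& tile : partial_tiles)
            for(std::size_t i = 0; i < tile.size(); ++i)
                dv_global[i] += tile[i]; // the atomic read-modify-write
    }

    // After: accumulate every pass in a thread-local buffer, store once
    // (models InMemoryDataOperationEnum::Set -- the "one block dv pass").
    void dv_one_block_style(std::vector<float>& dv_global,
                            const std::vector<std::vector<float>>& partial_tiles)
    {
        std::vector<float> acc(dv_global.size(), 0.0f); // thread buffer
        for(const auto& tile : partial_tiles)           // outer loop over M tiles
            for(std::size_t i = 0; i < tile.size(); ++i)
                acc[i] += tile[i];
        dv_global = acc;                                // single non-atomic write
    }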
@@ -1177,7 +1177,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                                const C0MatrixMask& c0_matrix_mask,
                                const float p_drop,
                                ck::philox& ph,
-                               const index_t block_idx_m)
+                               const index_t block_idx_n)
     {
         const FloatGemmAcc p_dropout  = type_convert<FloatGemmAcc>(1.0f - p_drop);
         const FloatGemmAcc rp_dropout = type_convert<FloatGemmAcc>(1.0f / p_dropout);
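Note: the two context lines above implement inverted dropout: p_dropout is the keep probability and rp_dropout rescales the survivors so the expected value of each element is unchanged (for p_drop = 0.2, survivors are scaled by 1/0.8 = 1.25). A small check of just that arithmetic:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float p_drop     = 0.2f;
        const float p_dropout  = 1.0f - p_drop;    // keep probability, 0.8
        const float rp_dropout = 1.0f / p_dropout; // rescale factor, 1.25
        // E[kept_value * rp_dropout] = p_dropout * rp_dropout * value = value
        assert(std::fabs(p_dropout * rp_dropout - 1.0f) < 1e-6f);
        return 0;
    }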
@@ -1217,11 +1217,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             return;
         }

-        const index_t block_work_idx_m = Deterministic ? block_idx_m : block_work_idx[I0];
+        const index_t block_work_idx_n = Deterministic ? block_idx_n : block_work_idx[I0];

         // HACK: this force m/o_block_data_idx_on_grid into SGPR
-        const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx_m * NPerBlock);
+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx_n * NPerBlock);
         const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
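Note: together with the parameter rename in the previous hunk, this flips what block_work_idx[I0] selects. Before, a workgroup owned an M tile (the old code multiplied the m index by NPerBlock, presumably relying on MPerBlock == NPerBlock in this configuration); now it owns an N tile and sweeps the M dimension in its outer loop, so one block sees every contribution to its dV strip. A hypothetical sketch of the origin arithmetic:

    #include <cstdint>

    struct TileOrigin
    {
        std::int32_t m0; // first M row this workgroup touches
        std::int32_t n0; // first N column this workgroup touches
    };

    // Before: block id picked the M tile; N was swept by the loop.
    TileOrigin origin_before(std::int32_t block_work_idx_m, std::int32_t per_block)
    {
        return {block_work_idx_m * per_block, 0};
    }

    // After: block id picks the N tile; M is swept by the loop instead.
    TileOrigin origin_after(std::int32_t block_work_idx_n, std::int32_t per_block)
    {
        return {0, block_work_idx_n * per_block};
    }

    int main()
    {
        return origin_after(3, 128).n0 == 384 ? 0 : 1; // tiny self-check
    }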
@@ -1254,7 +1254,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_gemm_tile_q_blockwise_copy =
             typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                 q_grid_desc_k0_m_k1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(0, 0, 0), // will loop over GemmM dimension
                 a_element_op,
                 Gemm0::a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -1264,7 +1264,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_gemm_tile_k_blockwise_copy =
             typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                 k_grid_desc_k0_n_k1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
                 Gemm0::b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -1276,9 +1276,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();

         const auto s_gemm_tile_a_block_reset_copy_step =
-            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+            make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), MPerBlock, 0);
         const auto s_gemm_tile_b_block_reset_copy_step =
-            make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);
+            make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), 0, 0);

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
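Note: the swapped reset steps encode the new loop direction. After each K sweep, the Q-side window now rewinds K0 and advances by MPerBlock (walking down the M dimension), while the K-side window rewinds K0 and stays put on the block's fixed N strip; previously the roles were reversed. A CPU model of the stepping, with illustrative names:

    #include <array>

    struct SliceWindow
    {
        std::array<int, 3> origin{}; // (k0, m_or_n, k1)
        void advance(const std::array<int, 3>& step)
        {
            for(int i = 0; i < 3; ++i)
                origin[i] += step[i];
        }
    };

    int main()
    {
        constexpr int K0 = 8, MPerBlock = 128;

        SliceWindow q_window; // A side: Q (and dY)
        SliceWindow k_window; // B side: K (and V)

        // One outer iteration: rewind K0 on both; only the A side steps M now.
        q_window.advance({-K0, MPerBlock, 0}); // new a_block_reset_copy_step
        k_window.advance({-K0, 0, 0});         // new b_block_reset_copy_step

        return (q_window.origin[1] == MPerBlock && k_window.origin[1] == 0) ? 0 : 1;
    }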
@@ -1293,7 +1293,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto pgrad_gemm_tile_ygrad_blockwise_copy =
             typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
                 ygrad_grid_desc_o0_m_o1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                make_multi_index(0, 0, 0), // will loop over GemmM dimension
                 tensor_operation::element_wise::PassThrough{},
                 Gemm0::a_block_desc_ak0_m_ak1,
                 make_multi_index(0, 0, 0),
@@ -1303,7 +1303,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto pgrad_gemm_tile_v_blockwise_copy =
             typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
                 v_grid_desc_o0_n_o1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                make_multi_index(0, n_block_data_idx_on_grid, 0),
                 tensor_operation::element_wise::PassThrough{},
                 Gemm0::b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 1,
                 1,
                 false>{
                 lse_grid_desc_mblock_mrepeat_mwave_mperxdl,
-                make_multi_index(block_work_idx_m,         // mblock
+                make_multi_index(0,                        // mblock
                                  acc0_thread_origin[I0],   // mrepeat
                                  acc0_thread_origin[I2],   // mwave
                                  acc0_thread_origin[I4])}; // mperxdl
@@ -1510,15 +1510,15 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
                 InMemoryDataOperationEnum::Set,
                 1, // DstScalarStrideInVector
                 true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                      make_multi_index(block_work_idx_m, // MBlockId
-                                       0,                // NBlockId
+                      make_multi_index(0,                // MBlockId
+                                       block_work_idx_n, // NBlockId
                                        0,                // mrepeat
                                        0,                // nrepeat
                                        wave_id[I0],      // MWaveId
                                        wave_id[I1],      // NWaveId
                                        wave_m_n_id[I1],  // MPerXdl
                                        0,                // group
                                        wave_m_n_id[I0],  // NInputIndex
                                        0),
                       tensor_operation::element_wise::PassThrough{}};
@@ -1551,7 +1551,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto vgrad_gemm_tile_ygrad_blockwise_copy =
             typename Gemm2::template BBlockwiseCopy<decltype(ygrad_grid_desc_m0_o_m1)>(
                 ygrad_grid_desc_m0_o_m1,
-                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                make_multi_index(0, // QLT
                                  o_block_data_idx_on_grid,
                                  0),
                 tensor_operation::element_wise::PassThrough{},
@@ -1563,6 +1563,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};

         auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();
+        v_slash_k_grad_thread_buf.Clear();

         // dV: C VGPR-to-global copy
         const auto vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
@@ -1597,7 +1598,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         auto kgrad_gemm_tile_q_blockwise_copy =
             typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
                 q_grid_desc_m0_k_m1,
-                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                make_multi_index(0, // QLT
                                  o_block_data_idx_on_grid,
                                  0),
                 tensor_operation::element_wise::PassThrough{},
@@ -1645,10 +1646,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         const auto y_thread_data_on_block_idx =
             y_thread_cluster_idx * y_thread_desc_m0_m1_o0_o1.GetLengths();

         const auto y_thread_data_on_grid_idx =
-            make_multi_index(
-                block_work_idx_m, I0, I0 /* all WGs start from o_block_idx = 0 */, I0) +
-            y_thread_data_on_block_idx;
+            y_thread_data_on_block_idx; // QLT

         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
@@ -1773,8 +1771,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
         index_t gemm1_m_block_outer_index = 0;

         do
         {
-            auto n_block_data_idx_on_grid =
-                __builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * NPerBlock);
+            auto m_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(gemm1_m_block_outer_index * MPerBlock);
             if(c0_matrix_mask.IsTileSkippable(
                    m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
             {
@@ -1912,62 +1910,54 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             ignore = gemm2_b_block_buf;
             ignore = v_slash_k_grad_thread_buf;

-            // SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
-            //                                                 s_blockwise_gemm.GetWaveIdx()[I1]);
-            // constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
-            // static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
-            //               "");
-            // // TODO: tune gemm2 pipeline
-            // // dV = P_drop^T * dY
-            // static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
-            //     v_slash_k_grad_thread_buf.Clear();
-            //     // load VGrad Gemm B
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
-            //                                                  ygrad_grid_buf);
-            //     // load VGrad Gemm A
-            //     const auto p_slice_idx =
-            //         Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
-            //     constexpr auto mwave_range = make_tuple(
-            //         p_slice_idx[I2],
-            //         p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
-            //     constexpr auto nwave_range = make_tuple(
-            //         p_slice_idx[I3],
-            //         p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
-            //     if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
-            //     {
-            //         vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
-            //             Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            //             make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
-            //             s_slash_p_thread_buf,
-            //             Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            //             gemm2_a_block_buf);
-            //     }
-            //     // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
-            //     buffer
-            //     // p slice window is moved by loop index
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-            //         ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
-            //     block_sync_lds(); // sync before write
-            //     vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
-            //                                                   gemm2_b_block_buf);
-            //     block_sync_lds(); // sync before read
-            //     v_slash_k_grad_blockwise_gemm.Run(
-            //         gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
-            // }); // end gemm dV
-            // // atomic_add dV
-            // vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-            //                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-            //                                      v_slash_k_grad_thread_buf,
-            //                                      vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-            //                                      vgrad_grid_buf);
+            SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                            s_blockwise_gemm.GetWaveIdx()[I1]);
+            constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+            static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() == num_gemm2_loop,
+                          "");
+            // TODO: tune gemm2 pipeline
+            // dV = P_drop^T * dY
+            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
+                // load VGrad Gemm B
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
+                                                             ygrad_grid_buf);
+                // load VGrad Gemm A
+                const auto p_slice_idx =
+                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                constexpr auto mwave_range = make_tuple(
+                    p_slice_idx[I2],
+                    p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                constexpr auto nwave_range = make_tuple(
+                    p_slice_idx[I3],
+                    p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
+
+                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                {
+                    vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
+                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
+                        s_slash_p_thread_buf,
+                        Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        gemm2_a_block_buf);
+                }
+
+                // ygrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+                // p slice window is moved by loop index
+                vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+                    ygrad_grid_desc_m0_o_m1, Gemm2::b_block_slice_copy_step);
+
+                block_sync_lds(); // sync before write
+                vgrad_gemm_tile_ygrad_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+                                                              gemm2_b_block_buf);
+                block_sync_lds(); // sync before read
+                v_slash_k_grad_blockwise_gemm.Run(
+                    gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
+            }); // end gemm dV

             // // gemm dP
             // block_sync_lds();
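Note: the re-enabled loop body above shows the standard LDS handshake: one barrier before the producer refills the shared tile ("sync before write"), a second before the blockwise GEMM consumes it ("sync before read"). A hedged C++20 std::barrier model of the same two-sync pattern (thread counts and names are invented):

    #include <barrier>
    #include <thread>
    #include <vector>

    int main()
    {
        constexpr int num_threads = 4, iters = 3;
        std::vector<int> lds(num_threads, 0); // stand-in for gemm2_b_block_buf
        std::barrier sync(num_threads);       // stand-in for block_sync_lds()

        auto worker = [&](int tid) {
            int acc = 0; // stand-in for v_slash_k_grad_thread_buf
            for(int i = 0; i < iters; ++i)
            {
                sync.arrive_and_wait(); // "sync before write": readers are done
                lds[tid] = tid + i;     // RunWrite: refill the shared tile
                sync.arrive_and_wait(); // "sync before read": writers are done
                for(int v : lds)        // blockwise gemm reads the whole tile
                    acc += v;
            }
            (void)acc; // the real kernel stores this once, after the outer loop
        };

        std::vector<std::thread> pool;
        for(int t = 0; t < num_threads; ++t)
            pool.emplace_back(worker, t);
        for(auto& th : pool)
            th.join();
        return 0;
    }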
@@ -2019,9 +2009,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             // // TODO: explore using dynamic buffer for a1 thread buffer
             // // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
-            // // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
+            // // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given
+            // that
             // // the A1 source buffer is static buffer holding the output of first GEMM and
-            // // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
+            // // requires constexpr offset by design. Therefore, we pass tensor coordinate
+            // offset
             // // explicitly in Run() below.

             // // preload data into LDS
@@ -2040,12 +2032,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             // static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
             // qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
-            //                                          Gemm1::a_block_slice_copy_step * i,
-            //                                          sgrad_thread_buf,
+            //                                          Gemm1::a_block_slice_copy_step *
+            //                                              i,
+            //                                          sgrad_thread_buf,
             //                                          Gemm1::a_thread_desc_k0_m_k1,
             //                                          make_tuple(I0, I0, I0),
             //                                          gemm1_a_thread_buf);
-            // qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);
+            // qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1,
+            //                                          k_grid_buf);
             // block_sync_lds();
@@ -2065,11 +2058,9 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             // {
             // qgrad_gemm_tile_sgrad_blockwise_copy.Run(
             //     Gemm1::a_src_thread_desc_k0_m_k1,
-            //     Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
-            //     sgrad_thread_buf,
-            //     Gemm1::a_thread_desc_k0_m_k1,
-            //     make_tuple(I0, I0, I0),
-            //     gemm1_a_thread_buf);
+            //     Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop -
+            //     1>{}, sgrad_thread_buf, Gemm1::a_thread_desc_k0_m_k1, make_tuple(I0, I0,
+            //     I0), gemm1_a_thread_buf);
             // block_sync_lds();
@@ -2107,7 +2098,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             //             gemm2_a_block_buf);
             //     }
-            //     // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+            //     // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
+            //     buffer
             //     // sgrad slice window is moved by loop index
             //     kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
             //                                                         Gemm2::b_block_slice_copy_step);
@@ -2135,11 +2127,11 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                 k_grid_desc_k0_n_k1,
                 s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
-            vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
-                ygrad_grid_desc_m0_o_m1,
-                Gemm2::b_block_reset_copy_step); // rewind M
-            vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
-                vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
-                Gemm2::c_block_slice_copy_step); // step N
+            // vgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
+            //     ygrad_grid_desc_m0_o_m1,
+            //     Gemm2::b_block_reset_copy_step); // rewind M
+            // vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+            //     vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
             pgrad_gemm_tile_ygrad_blockwise_copy.MoveSrcSliceWindow(
                 ygrad_grid_desc_o0_m_o1,
                 pgrad_gemm_tile_ygrad_block_reset_copy_step); // rewind O
             pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
@@ -2153,10 +2145,17 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
                 z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
-                make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
+                make_multi_index(1, 0, 0, 0, 0, 0, 0, 0, 0, 0));
         } while(++gemm1_m_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

+        // atomic_add dV
+        vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                             make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                             v_slash_k_grad_thread_buf,
+                                             vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                             vgrad_grid_buf);

         ignore = c_element_op;
         ignore = qgrad_grid_buf;
         ignore = qgrad_grid_desc_mblock_mperblock_kblock_kperblock;
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -917,7 +917,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
         {
             block_sync_lds();
         }

         do
         {
             auto n_block_data_idx_on_grid =
...