Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
70fca3ed
Commit
70fca3ed
authored
May 19, 2023
by
ltqin
Browse files
one block dk pass
parent
7af1b43a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
116 additions
and
110 deletions
+116
-110
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
+9
-9
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp
...tched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp
+107
-101
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2.hpp
View file @
70fca3ed
...
...
@@ -1448,7 +1448,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1
,
false
>
{
lse_grid_desc_mblock_mrepeat_mwave_mperxdl
,
make_multi_index
(
block_work_idx_m
,
// mblock
make_multi_index
(
block_work_idx_m
,
// mblock
acc0_thread_origin
[
I0
],
// mrepeat
acc0_thread_origin
[
I2
],
// mwave
acc0_thread_origin
[
I4
])};
// mperxdl
...
...
@@ -1511,14 +1511,14 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
1
,
// DstScalarStrideInVector
true
>
{
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
block_work_idx_m
,
// MBlockId
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
,
// NBlockId
0
,
// mrepeat
0
,
// nrepeat
wave_id
[
I0
],
// MWaveId
wave_id
[
I1
],
// NWaveId
wave_m_n_id
[
I1
],
// MPerXdl
0
,
// group
wave_m_n_id
[
I0
],
// NInputIndex
0
),
tensor_operation
::
element_wise
::
PassThrough
{}};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt2v2.hpp
View file @
70fca3ed
...
...
@@ -1579,6 +1579,13 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
vgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4
,
tensor_operation
::
element_wise
::
Scale
{
rp_dropout
});
// dK:
// dV: blockwise gemm
auto
k_grad_blockwise_gemm
=
typename
Gemm2
::
BlockwiseGemm
{};
auto
k_grad_thread_buf
=
k_grad_blockwise_gemm
.
GetCThreadBuffer
();
k_grad_thread_buf
.
Clear
();
// dK: transform input and output tensor descriptors
const
auto
q_grid_desc_m0_k_m1
=
KGradGemmTile_N_K_M
::
MakeQGridDesc_M0_K_M1
(
q_grid_desc_k0_m_k1
);
...
...
@@ -1959,50 +1966,92 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
});
// end gemm dV
// // gemm dP
// block_sync_lds();
// // dP = dY * V^T
// // assume size K == size O so HasMainKBlockLoop is the same
// gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
// ygrad_grid_desc_o0_m_o1,
// Gemm0::a_block_desc_ak0_m_ak1, // reuse
// pgrad_gemm_tile_ygrad_blockwise_copy,
// ygrad_grid_buf,
// gemm0_a_block_buf, // reuse
// Gemm0::a_block_slice_copy_step, // reuse
// v_grid_desc_o0_n_o1,
// Gemm0::b_block_desc_bk0_n_bk1, // reuse
// pgrad_gemm_tile_v_blockwise_copy,
// v_grid_buf,
// gemm0_b_block_buf, // reuse
// Gemm0::b_block_slice_copy_step, // reuse
// pgrad_blockwise_gemm,
// pgrad_thread_buf,
// num_o_block_main_loop);
// // dS = P * (dP - Y_dot_dY)
// auto& sgrad_thread_buf = pgrad_thread_buf;
// constexpr auto pgrad_thread_tile_iterator =
// pgrad_blockwise_gemm.MakeCThreadTileIterator();
// constexpr auto pgrad_thread_idx_to_m_n_adaptor =
// pgrad_blockwise_gemm.MakeCThreadIndexAdaptor8DTo2D();
// static_for<0, pgrad_thread_tile_iterator.GetNumOfAccess(), 1>{}([&](auto i) {
// constexpr auto pgrad_thread_idx = pgrad_thread_tile_iterator.GetIndex(i);
// constexpr auto m =
// pgrad_thread_idx_to_m_n_adaptor.CalculateBottomIndex(pgrad_thread_idx)[I0];
// // dS and P has same thread buf layout
// if(s_slash_p_thread_buf[i] >= 0)
// {
// sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] *
// (pgrad_thread_buf[i] - y_dot_ygrad_thread_buf[Number<m>{}]);
// }
// else
// {
// sgrad_thread_buf(i) =
// s_slash_p_thread_buf[i] * y_dot_ygrad_thread_buf[Number<m>{}];
// }
// });
// gemm dP
block_sync_lds
();
// dP = dY * V^T
// assume size K == size O so HasMainKBlockLoop is the same
gemm0_gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
ygrad_grid_desc_o0_m_o1
,
Gemm0
::
a_block_desc_ak0_m_ak1
,
// reuse
pgrad_gemm_tile_ygrad_blockwise_copy
,
ygrad_grid_buf
,
gemm0_a_block_buf
,
// reuse
Gemm0
::
a_block_slice_copy_step
,
// reuse
v_grid_desc_o0_n_o1
,
Gemm0
::
b_block_desc_bk0_n_bk1
,
// reuse
pgrad_gemm_tile_v_blockwise_copy
,
v_grid_buf
,
gemm0_b_block_buf
,
// reuse
Gemm0
::
b_block_slice_copy_step
,
// reuse
pgrad_blockwise_gemm
,
pgrad_thread_buf
,
num_o_block_main_loop
);
// dS = P * (dP - Y_dot_dY)
auto
&
sgrad_thread_buf
=
pgrad_thread_buf
;
constexpr
auto
pgrad_thread_tile_iterator
=
pgrad_blockwise_gemm
.
MakeCThreadTileIterator
();
constexpr
auto
pgrad_thread_idx_to_m_n_adaptor
=
pgrad_blockwise_gemm
.
MakeCThreadIndexAdaptor8DTo2D
();
static_for
<
0
,
pgrad_thread_tile_iterator
.
GetNumOfAccess
(),
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
pgrad_thread_idx
=
pgrad_thread_tile_iterator
.
GetIndex
(
i
);
constexpr
auto
m
=
pgrad_thread_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
pgrad_thread_idx
)[
I0
];
// dS and P has same thread buf layout
if
(
s_slash_p_thread_buf
[
i
]
>=
0
)
{
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
(
pgrad_thread_buf
[
i
]
-
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}]);
}
else
{
sgrad_thread_buf
(
i
)
=
s_slash_p_thread_buf
[
i
]
*
y_dot_ygrad_thread_buf
[
Number
<
m
>
{}];
}
});
// dK = scalar * dS^T * Q
static_for
<
0
,
num_gemm2_loop
,
1
>
{}([
&
](
auto
gemm2_loop_idx
)
{
// gemm dK
// load KGrad Gemm B
kgrad_gemm_tile_q_blockwise_copy
.
RunRead
(
q_grid_desc_m0_k_m1
,
q_grid_buf
);
// load KGrad Gemm A
const
auto
sgrad_slice_idx
=
Gemm2
::
ASrcBlockSliceWindowIterator
::
GetIndexTupleOfNumber
(
gemm2_loop_idx
);
constexpr
auto
mwave_range
=
make_tuple
(
sgrad_slice_idx
[
I2
],
sgrad_slice_idx
[
I2
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I2
));
constexpr
auto
nwave_range
=
make_tuple
(
sgrad_slice_idx
[
I3
],
sgrad_slice_idx
[
I3
]
+
Gemm2Params_N_O_M
::
ABlockSliceLengths_M0_N0_M1_N1
::
At
(
I3
));
if
(
gemm2_a_copy_subgroup
.
IsBelong
(
mwave_range
,
nwave_range
))
{
kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds
.
Run
(
Gemm2
::
a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
make_tuple
(
sgrad_slice_idx
[
I0
],
sgrad_slice_idx
[
I1
],
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
sgrad_thread_buf
,
Gemm2
::
a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4
,
gemm2_a_block_buf
);
}
// kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
// sgrad slice window is moved by loop index
kgrad_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
q_grid_desc_m0_k_m1
,
Gemm2
::
b_block_slice_copy_step
);
block_sync_lds
();
// sync before write
kgrad_gemm_tile_q_blockwise_copy
.
RunWrite
(
Gemm2
::
b_block_desc_m0_o_m1
,
gemm2_b_block_buf
);
block_sync_lds
();
// sync before read
k_grad_blockwise_gemm
.
Run
(
gemm2_a_block_buf
,
gemm2_b_block_buf
,
k_grad_thread_buf
);
});
// end gemm dK
// // gemm dQ
// // dQ = scalar * dS * K
...
...
@@ -2069,57 +2118,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// }
// } // end gemm dQ
// // dK = scalar * dS^T * Q
// v_slash_k_grad_thread_buf.Clear();
// static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
// // load KGrad Gemm B
// kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
// // load KGrad Gemm A
// const auto sgrad_slice_idx =
// Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
// constexpr auto mwave_range =
// make_tuple(sgrad_slice_idx[I2],
// sgrad_slice_idx[I2] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
// constexpr auto nwave_range =
// make_tuple(sgrad_slice_idx[I3],
// sgrad_slice_idx[I3] +
// Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
// if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
// {
// kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
// Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// make_tuple(
// sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
// sgrad_thread_buf,
// Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
// gemm2_a_block_buf);
// }
// // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic
// buffer
// // sgrad slice window is moved by loop index
// kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
// Gemm2::b_block_slice_copy_step);
// block_sync_lds(); // sync before write
// kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
// gemm2_b_block_buf);
// block_sync_lds(); // sync before read
// v_slash_k_grad_blockwise_gemm.Run(
// gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
// }); // end gemm dK
// // atomic_add dK
// kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
// v_slash_k_grad_thread_buf,
// kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
// kgrad_grid_buf);
// move slice window
s_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
q_grid_desc_k0_m_k1
,
...
...
@@ -2131,17 +2129,19 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
// ygrad_grid_desc_m0_o_m1,
// Gemm2::b_block_reset_copy_step); // rewind M
// vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step N
// vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
pgrad_gemm_tile_ygrad_blockwise_copy
.
MoveSrcSliceWindow
(
ygrad_grid_desc_o0_m_o1
,
pgrad_gemm_tile_ygrad_block_reset_copy_step
);
// rewind O
pgrad_gemm_tile_v_blockwise_copy
.
MoveSrcSliceWindow
(
v_grid_desc_o0_n_o1
,
pgrad_gemm_tile_v_block_reset_copy_step
);
// rewind O and step N
kgrad_gemm_tile_q_blockwise_copy
.
MoveSrcSliceWindow
(
q_grid_desc_m0_k_m1
,
Gemm2
::
b_block_reset_copy_step
);
// rewind M
kgrad_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
Gemm2
::
c_block_slice_copy_step
);
// step N
// kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
// q_grid_desc_m0_k_m1,
// Gemm2::b_block_reset_copy_step); // rewind M
// kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4, Gemm2::c_block_slice_copy_step); // step
// N
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
...
@@ -2155,6 +2155,12 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
v_slash_k_grad_thread_buf
,
vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
vgrad_grid_buf
);
// atomic_add dK
kgrad_thread_copy_vgpr_to_global
.
Run
(
Gemm2
::
c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
k_grad_thread_buf
,
kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4
,
kgrad_grid_buf
);
ignore
=
c_element_op
;
ignore
=
qgrad_grid_buf
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment