gaoqiong / composable_kernel

Commit ec9c2b5e, authored Dec 26, 2022 by Anthony Chang

dK validates

Parent: 2d55c14c
Showing 2 changed files with 175 additions and 30 deletions:

  example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp  (+5 -2)
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  (+170 -28)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_backward_fp16.cpp:
@@ -647,7 +647,10 @@ int run(int argc, char* argv[])
                                      1e-2);
         std::cout << "Checking kgrad:\n";
         pass &= ck::utils::check_err(kgrad_gs_ns_ks_device_result.mData,
-                                     kgrad_gs_ns_ks_host_result.mData);
+                                     kgrad_gs_ns_ks_host_result.mData,
+                                     "error",
+                                     1e-2,
+                                     1e-2);
         std::cout << "Checking vgrad:\n";
         pass &= ck::utils::check_err(vgrad_gs_os_ns_device_result.mData,
                                      vgrad_gs_os_ns_host_result.mData,
@@ -656,7 +659,7 @@ int run(int argc, char* argv[])
                                      1e-2);
     }

-    return pass ? 0 : 1;
+    return pass ? (std::cout << "pass\n", 0) : (std::cout << "fail\n", 1);
 }

 int main(int argc, char* argv[]) { return run(argc, argv); }
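Two small but useful changes in this example: the dK `check_err` call now passes an explicit message and relative/absolute tolerances (1e-2 each), matching the other gradient checks, and the return statement uses the comma operator so the pass/fail verdict is printed and the exit code produced in one expression. For orientation, here is a minimal standalone sketch of the mixed relative/absolute comparison a call like `check_err(dev, host, "error", rtol, atol)` implies; this is an illustration only, not CK's actual implementation:

// Sketch only: mixed relative/absolute tolerance check, not CK's code.
#include <cmath>
#include <cstdio>
#include <vector>

bool check_err_sketch(const std::vector<float>& dev,
                      const std::vector<float>& host,
                      const char* msg,
                      float rtol,
                      float atol)
{
    if(dev.size() != host.size())
    {
        std::printf("%s: size mismatch %zu vs %zu\n", msg, dev.size(), host.size());
        return false;
    }
    for(std::size_t i = 0; i < dev.size(); ++i)
    {
        // an element passes if |dev - host| <= atol + rtol * |host|
        if(std::fabs(dev[i] - host[i]) > atol + rtol * std::fabs(host[i]))
        {
            std::printf("%s: mismatch at %zu: %f vs %f\n", msg, i, dev[i], host[i]);
            return false;
        }
    }
    return true;
}

int main()
{
    std::vector<float> dev{1.00f, 2.01f}, host{1.0f, 2.0f};
    return check_err_sketch(dev, host, "error", 1e-2f, 1e-2f) ? 0 : 1;
}

Loose 1e-2 tolerances are the norm here because fp16 device accumulation and the fp32 host reference round differently, so exact equality is not meaningful.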
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:
@@ -789,7 +789,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     template <typename CGradDesc_N_O>
     __host__ __device__ static const auto
-    MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(CGradDesc_N_O c_grid_desc_n_o)
+    MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(const CGradDesc_N_O& c_grid_desc_n_o)
     {
         // HACK: for unmerge transform, the length of highest dim is irrelevant so we put dummy
         // variable I1 there
@@ -859,7 +859,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     // PGrad Gemm has the same layout as P = Q * K^T Gemm (A row-major B col-major)
     struct PGradGemmTile_M_N_O
     {
-        // TODO ANT:
+        // TODO:
         // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
         // things more concise
         template <typename YGradGridDesc_M0_O_M1_>
@@ -957,6 +957,48 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };

+    struct KGradGemmTile_N_K_M
+    {
+        // B position
+        template <typename QGridDesc_K0_M_K1_>
+        __device__ static const auto
+        MakeQGridDesc_M0_K_M1(const QGridDesc_K0_M_K1_& q_grid_desc_k0_m_k1)
+        {
+            const auto Q_K0 = q_grid_desc_k0_m_k1.GetLength(I0);
+            const auto M    = q_grid_desc_k0_m_k1.GetLength(I1);
+            const auto Q_K1 = q_grid_desc_k0_m_k1.GetLength(I2);
+
+            constexpr auto Q_M1 = B1K1;
+            const auto Q_M0     = M / Q_M1;
+
+            const auto q_grid_desc_m0_k_m1 = transform_tensor_descriptor(
+                q_grid_desc_k0_m_k1,
+                make_tuple(make_unmerge_transform(make_tuple(Q_M0, Q_M1)),
+                           make_merge_transform_v3_division_mod(make_tuple(Q_K0, Q_K1))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return q_grid_desc_m0_k_m1;
+        }
+
+        // C position
+        template <typename KGridDesc_K0_N_K1_>
+        __device__ static const auto
+        MakeKGradGridDesc_N_K(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
+        {
+            const auto K_K0 = k_grid_desc_k0_n_k1.GetLength(I0);
+            const auto N    = k_grid_desc_k0_n_k1.GetLength(I1);
+            const auto K_K1 = k_grid_desc_k0_n_k1.GetLength(I2);
+
+            return transform_tensor_descriptor(
+                k_grid_desc_k0_n_k1,
+                make_tuple(make_pass_through_transform(N),
+                           make_merge_transform_v3_division_mod(make_tuple(K_K0, K_K1))),
+                make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+    };
+
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment
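The new `KGradGemmTile_N_K_M` helper re-views the existing Q and K descriptors for the dK GEMM without touching memory: `make_unmerge_transform` splits the M axis into an (M0, M1) pair with M1 = B1K1, and `make_merge_transform_v3_division_mod` folds the (K0, K1) pair back into a flat K index using division/mod. A small standalone sketch of that index arithmetic, with hypothetical lengths and none of CK's descriptor machinery:

// Sketch of merge/unmerge index math; K0, K1, M1 values are illustrative.
#include <cassert>
#include <cstdio>

int main()
{
    const int K0 = 4, K1 = 8; // merged: k = k0 * K1 + k1
    const int M1 = 2;         // plays the role of B1K1 above

    for(int k0 = 0; k0 < K0; ++k0)
        for(int k1 = 0; k1 < K1; ++k1)
        {
            const int k = k0 * K1 + k1;           // merge
            assert(k / K1 == k0 && k % K1 == k1); // recover via division/mod
        }

    const int m  = 7;
    const int m0 = m / M1, m1 = m % M1; // unmerge m -> (m0, m1)
    std::printf("m=%d -> (m0=%d, m1=%d)\n", m, m0, m1);
    return 0;
}

Because these are pure coordinate transforms, the same global Q and K buffers can be read in whatever (row, column) view each of the backward GEMMs needs.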
@@ -1067,7 +1109,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             p_v_grid, v_grid_desc_n0_o_n1.GetElementSpaceSize());
         const auto y_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_y_grid, y_grid_desc_mblock_mperblock_oblock_operblock.GetElementSpaceSize());
-        auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+        const auto lse_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_lse_grid, lse_grid_desc_m.GetElementSpaceSize());
         const auto ygrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_ygrad_grid, ygrad_grid_desc_m0_o_m1.GetElementSpaceSize());
@@ -1075,6 +1117,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             p_vgrad_grid, vgrad_grid_desc_n_o.GetElementSpaceSize());
         auto qgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_qgrad_grid, q_grid_desc_k0_m_k1.GetElementSpaceSize());
+        auto kgrad_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_kgrad_grid, k_grid_desc_k0_n_k1.GetElementSpaceSize());

         // divide block work by [M, O]
         const auto block_work_idx =
@@ -1095,6 +1139,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const index_t o_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

+        // 6 GEMM operations are categorized into 3 buckets. SizeK == SizeO == head_dim
+        // S_MNK / dP_MNO Gemm (Gemm0 rcr)
+        // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
+        // dV_NOM / dK_NKM Gemm (Gemm2 crr)
+
         //
         // set up S / dP Gemm (type 1 rcr)
         //
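For orientation, the six GEMMs bucketed above correspond to the standard attention backward pass. Writing the forward pass as S = QK^T, P = softmax(S), Y = PV, a common textbook formulation (stated here for reference, not quoted from this file) is:

% Forward: S = Q K^T,  P = softmax(S) row-wise,  Y = P V
\begin{aligned}
dV &= P^{T}\, dY \\
dP &= dY\, V^{T} \\
dS_{ij} &= P_{ij}\Bigl(dP_{ij} - \sum_{k} dP_{ik}\, P_{ik}\Bigr) \\
dQ &= dS\, K \\
dK &= dS^{T} Q
\end{aligned}

The Gemm2 bucket covers both dV and dK because both contract over the sequence dimension M with a column-major A operand (per the crr note above), which is exactly why this commit renames the dV-specific GEMM objects to shared v_slash_k_* names and reuses them for dK.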
@@ -1211,11 +1260,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
                 q_grid_desc_k0_m_k1);

-        // dQ: Gemm A matrix blockwise copy
+        // dQ: A matrix blockwise copy
         auto qgrad_gemm_tile_sgrad_blockwise_copy =
             typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};

-        // dQ: Gemm B matrix blockwise copy
+        // dQ: B matrix blockwise copy
         auto qgrad_gemm_tile_k_blockwise_copy =
             typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
                 k_grid_desc_n0_k_n1,
@@ -1357,9 +1406,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             tensor_operation::element_wise::PassThrough{});

         // dV: blockwise gemm
-        auto vgrad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
-        auto vgrad_thread_buf     = vgrad_blockwise_gemm.GetCThreadBuffer();
+        auto v_slash_k_grad_blockwise_gemm = typename Gemm2::BlockwiseGemm{};
+        auto v_slash_k_grad_thread_buf = v_slash_k_grad_blockwise_gemm.GetCThreadBuffer();

         // dV: C VGPR-to-global copy
         const auto vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
@@ -1376,6 +1425,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             tensor_operation::element_wise::PassThrough{});

+        // dK: transform input and output tensor descriptors
+        const auto q_grid_desc_m0_k_m1 =
+            KGradGemmTile_N_K_M::MakeQGridDesc_M0_K_M1(q_grid_desc_k0_m_k1);
+        const auto kgrad_grid_desc_n_k =
+            KGradGemmTile_N_K_M::MakeKGradGridDesc_N_K(k_grid_desc_k0_n_k1);
+
+        // dK: A matrix VGPR-to-LDS blockwise copy
+        auto kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds = typename Gemm2::ABlockwiseCopy{
+            Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+            Gemm2::MakeAThreadOriginOnBlock_M0_N0_M1_N1_M2_N2_N3_N4(),
+            tensor_operation::element_wise::PassThrough{}};
+
+        // dK: B matrix global-to-LDS blockwise copy
+        auto kgrad_gemm_tile_q_blockwise_copy =
+            typename Gemm2::template BBlockwiseCopy<decltype(q_grid_desc_m0_k_m1)>(
+                q_grid_desc_m0_k_m1,
+                make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1,
+                                 o_block_data_idx_on_grid,
+                                 0),
+                tensor_operation::element_wise::PassThrough{},
+                Gemm2::b_block_desc_m0_o_m1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        // dK: blockwise gemm
+        /* reuse v_slash_k_grad_blockwise_gemm, v_slash_k_grad_thread_buf */
+
+        // dK: C VGPR-to-global copy
+        const auto kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::MakeCGridDesc_N0_O0_N1_O1_N2_O2_O3_O4(kgrad_grid_desc_n_k);
+
+        const auto kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4 =
+            Gemm2::GetCThreadOriginOnBlock_N0_O0_N1_O1_N2_O2_O3_O4() +
+            make_multi_index(I0, block_work_idx[I1], I0, I0, I0, I0, I0, I0);
+
+        auto kgrad_thread_copy_vgpr_to_global = typename Gemm2::template CBlockwiseCopy<
+            decltype(kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4)>(
+            kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+            kgrad_thread_origin_on_grid_n0_o0_n1_o1_n2_o2_o3_o4,
+            tensor_operation::element_wise::PassThrough{});
+
         //
         // set up Y dot dY
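One detail worth unpacking in the dK B-copy origin: `make_multi_index(m_block_data_idx_on_grid / Gemm2Params_N_O_M::B_M1, o_block_data_idx_on_grid, 0)` converts the block's flat m offset into the M0 coordinate of the unmerged (M0, K, M1) view of Q built above, since that view encodes m = m0 * M1 + m1. A sketch with illustrative values (B_M1 and the offsets are stand-ins, not values fixed by this file):

// Sketch of the B-copy start coordinate computation above.
#include <cstdio>

int main()
{
    const int B_M1                     = 2;  // stand-in for Gemm2Params_N_O_M::B_M1
    const int m_block_data_idx_on_grid = 96; // this block's first m row
    const int o_block_data_idx_on_grid = 0;  // this block's first column

    const int m0 = m_block_data_idx_on_grid / B_M1; // M0 coordinate of the window
    std::printf("copy window origin = (m0=%d, k=%d, m1=%d)\n",
                m0, o_block_data_idx_on_grid, 0);
    return 0;
}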
@@ -1618,38 +1706,36 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 block_sync_lds(); // wait for gemm1 LDS read

-                constexpr auto p_block_slice_lengths_m0_n0_m1_n1 =
-                    typename Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1{};
-                SubThreadBlock<BlockSize> p_thread_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
-                                                                 s_blockwise_gemm.GetWaveIdx()[I1]);
+                SubThreadBlock<BlockSize> gemm2_a_copy_subgroup(s_blockwise_gemm.GetWaveIdx()[I0],
+                                                                s_blockwise_gemm.GetWaveIdx()[I1]);

-                constexpr index_t num_vgrad_gemm_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
-                static_assert(vgrad_gemm_tile_p_block_slice_window_iterator.GetNumOfAccess() ==
-                                  num_vgrad_gemm_loop,
+                constexpr index_t num_gemm2_loop = MPerBlock / Gemm2Params_N_O_M::Sum_M;
+                static_assert(Gemm2::ASrcBlockSliceWindowIterator::GetNumOfAccess() ==
+                                  num_gemm2_loop,
                               "");

                 // TODO: tune gemm2 pipeline
                 // dV = P^T * dY
-                vgrad_thread_buf.Clear();
-                static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
+                v_slash_k_grad_thread_buf.Clear();
+                static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dV
                     // load VGrad Gemm B
                     vgrad_gemm_tile_ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1,
                                                                  ygrad_grid_buf);
                     // load VGrad Gemm A
-                    const auto p_nd_idx =
-                        vgrad_gemm_tile_p_block_slice_window_iterator.GetIndexTupleOfNumber(
-                            vgrad_gemm_loop_idx);
-                    constexpr auto mwave_range = make_tuple(
-                        p_nd_idx[I2], p_nd_idx[I2] + p_block_slice_lengths_m0_n0_m1_n1[I2]);
-                    constexpr auto nwave_range = make_tuple(
-                        p_nd_idx[I3], p_nd_idx[I3] + p_block_slice_lengths_m0_n0_m1_n1[I3]);
+                    const auto p_slice_idx =
+                        Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                    constexpr auto mwave_range = make_tuple(
+                        p_slice_idx[I2],
+                        p_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                    constexpr auto nwave_range = make_tuple(
+                        p_slice_idx[I3],
+                        p_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));

-                    if(p_thread_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                    if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
                     {
                         vgrad_gemm_tile_p_thread_copy_vgpr_to_lds.Run(
                             Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                            make_tuple(p_nd_idx[I0], p_nd_idx[I1], I0, I0, I0, I0, I0, I0),
+                            make_tuple(p_slice_idx[I0], p_slice_idx[I1], I0, I0, I0, I0, I0, I0),
                             s_slash_p_thread_buf,
                             Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                             gemm2_a_block_buf);
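This rework replaces the dV-local `p_thread_copy_subgroup` and `p_block_slice_lengths_m0_n0_m1_n1` with `Gemm2::ASrcBlockSliceWindowIterator` and a shared `gemm2_a_copy_subgroup`, so the same slice-window logic can drive both the dV loop and the new dK loop below. The `IsBelong(mwave_range, nwave_range)` guard lets only those waves whose coordinates fall inside the current slice write their VGPR fragment to LDS. A standalone sketch of that membership test, with a hypothetical wave grid rather than CK's SubThreadBlock:

// Sketch of a wave-range membership test; grid shape and ranges illustrative.
#include <cstdio>
#include <utility>

// A wave "belongs" iff its (mwave, nwave) coordinates fall inside the
// half-open ranges owned by the current slice-window iteration.
bool is_belong(int mwave, int nwave,
               std::pair<int, int> mwave_range,
               std::pair<int, int> nwave_range)
{
    return mwave >= mwave_range.first && mwave < mwave_range.second &&
           nwave >= nwave_range.first && nwave < nwave_range.second;
}

int main()
{
    // hypothetical 4x2 grid of waves; window covers mwaves [1,3), nwaves [0,2)
    for(int m = 0; m < 4; ++m)
        for(int n = 0; n < 2; ++n)
            std::printf("wave (%d,%d): %s\n", m, n,
                        is_belong(m, n, {1, 3}, {0, 2}) ? "copies" : "idle");
    return 0;
}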
@@ -1665,13 +1751,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                         gemm2_b_block_buf);

                     block_sync_lds(); // sync before read
-                    vgrad_blockwise_gemm.Run(gemm2_a_block_buf, gemm2_b_block_buf, vgrad_thread_buf);
+                    v_slash_k_grad_blockwise_gemm.Run(
+                        gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
                 }); // end gemm dV

                 // atomic_add dV
                 vgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                      make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
-                                                     vgrad_thread_buf,
+                                                     v_slash_k_grad_thread_buf,
                                                      vgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                      vgrad_grid_buf);
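The `// atomic_add dV` step exists because dV is a reduction over the M (sequence) dimension: partial N x O tiles produced for different M slices, potentially by different workgroups, land on the same destination, so the VGPR-to-global copy must accumulate rather than overwrite. A minimal HIP sketch of that accumulation pattern (illustrative kernel, not CK's copy class):

// Illustrative HIP kernel: partial tiles from several workgroups target
// the same destination, so accumulation must be atomic.
#include <hip/hip_runtime.h>

__global__ void accumulate_partial(float* __restrict__ dst,
                                   const float* __restrict__ partial,
                                   int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        atomicAdd(&dst[i], partial[i]); // a plain store here would race
    }
}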
@@ -1777,6 +1864,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
             } // end gemm dQ

+            // dK = dS^T * dQ
+            v_slash_k_grad_thread_buf.Clear();
+            static_for<0, num_gemm2_loop, 1>{}([&](auto gemm2_loop_idx) { // gemm dK
+                // load KGrad Gemm B
+                kgrad_gemm_tile_q_blockwise_copy.RunRead(q_grid_desc_m0_k_m1, q_grid_buf);
+
+                // load KGrad Gemm A
+                const auto sgrad_slice_idx =
+                    Gemm2::ASrcBlockSliceWindowIterator::GetIndexTupleOfNumber(gemm2_loop_idx);
+                constexpr auto mwave_range = make_tuple(
+                    sgrad_slice_idx[I2],
+                    sgrad_slice_idx[I2] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I2));
+                constexpr auto nwave_range = make_tuple(
+                    sgrad_slice_idx[I3],
+                    sgrad_slice_idx[I3] + Gemm2Params_N_O_M::ABlockSliceLengths_M0_N0_M1_N1::At(I3));
+
+                if(gemm2_a_copy_subgroup.IsBelong(mwave_range, nwave_range))
+                {
+                    kgrad_gemm_tile_sgrad_thread_copy_vgpr_to_lds.Run(
+                        Gemm2::a_src_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        make_tuple(sgrad_slice_idx[I0], sgrad_slice_idx[I1], I0, I0, I0, I0, I0, I0),
+                        sgrad_thread_buf,
+                        Gemm2::a_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
+                        gemm2_a_block_buf);
+                }
+
+                // kgrad slice window is moved with MoveSrcSliceWindow() since it is dynamic buffer
+                // sgrad slice window is moved by loop index
+                kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_m0_k_m1,
+                                                                    Gemm2::b_block_slice_copy_step);
+
+                block_sync_lds(); // sync before write
+                kgrad_gemm_tile_q_blockwise_copy.RunWrite(Gemm2::b_block_desc_m0_o_m1,
+                                                          gemm2_b_block_buf);
+
+                block_sync_lds(); // sync before read
+                v_slash_k_grad_blockwise_gemm.Run(
+                    gemm2_a_block_buf, gemm2_b_block_buf, v_slash_k_grad_thread_buf);
+            }); // end gemm dK
+
+            // atomic_add dK
+            kgrad_thread_copy_vgpr_to_global.Run(Gemm2::c_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+                                                 v_slash_k_grad_thread_buf,
+                                                 kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                                                 kgrad_grid_buf);
+
             // move slice window
             s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
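The dK loop (the heart of this "dK validates" commit) mirrors the dV loop structure exactly, which is what the earlier renames enable: the dS fragments play the A role from VGPRs, Q tiles are streamed into LDS as B, and the LDS staging uses the same two-barrier discipline: `block_sync_lds()` before `RunWrite` so no wave overwrites a tile another wave is still reading, and again before the blockwise GEMM consumes the new tile. A simplified HIP sketch of that pattern (stand-in kernel, assumes a launch with 256 threads per block, not CK's pipeline):

// Illustrative HIP kernel showing the two-barrier LDS tile pattern.
#include <hip/hip_runtime.h>

__global__ void lds_two_barrier_pattern(const float* __restrict__ src,
                                        float* __restrict__ dst,
                                        int num_tiles)
{
    __shared__ float tile[256];
    float acc = 0.f;
    for(int t = 0; t < num_tiles; ++t)
    {
        __syncthreads(); // sync before write: prior read must be finished
        tile[threadIdx.x] = src[t * 256 + threadIdx.x];
        __syncthreads(); // sync before read: whole tile must be written
        acc += tile[(threadIdx.x + 1) % 256]; // stand-in for the blockwise gemm
    }
    dst[blockIdx.x * 256 + threadIdx.x] = acc;
}

Note that while the source comment reads `dK = dS^T * dQ`, the B operand actually loaded is `q_grid_buf`, consistent with the usual gradient dK = dS^T Q.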
@@ -1794,6 +1931,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             pgrad_gemm_tile_v_blockwise_copy.MoveSrcSliceWindow(
                 v_grid_desc_o0_n_o1,
                 pgrad_gemm_tile_v_block_reset_copy_step); // rewind O and step N
+            kgrad_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+                q_grid_desc_m0_k_m1, Gemm2::b_block_reset_copy_step); // rewind M
+            kgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
+                kgrad_grid_desc_n0_o0_n1_o1_n2_o2_o3_o4,
+                Gemm2::c_block_slice_copy_step); // step N

         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
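The new bookkeeping at the bottom of the j loop mirrors what the other operands already do: the Q source window for dK, which stepped along M inside the inner gemm2 loop, is rewound by `Gemm2::b_block_reset_copy_step` once per outer iteration, while the dK destination window steps along N with `Gemm2::c_block_slice_copy_step`. A plain-integer sketch of this step-and-rewind pattern (loop counts and step size are illustrative):

// Sketch of slice-window step/rewind bookkeeping across nested loops.
#include <cstdio>

int main()
{
    const int num_outer = 3, num_inner = 4, step = 16;
    int offset = 0;
    for(int j = 0; j < num_outer; ++j)
    {
        for(int i = 0; i < num_inner; ++i)
        {
            std::printf("outer %d, inner %d: window at %d\n", j, i, offset);
            offset += step;         // MoveSrcSliceWindow: step along M
        }
        offset -= num_inner * step; // reset copy step: rewind M to the start
    }
    return 0;
}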