Project: gaoqiong/composable_kernel

Commit 383211ef, authored Dec 21, 2022 by Anthony Chang
rearrange gemm0/gemm1
parent d13c92bd
Showing 1 changed file with 160 additions and 208 deletions.

include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+160, -208)
@@ -149,7 +149,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     __host__ __device__ static constexpr auto GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
         const AccThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4& acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
     {
-        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
+        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to a_src_thread_desc_k0_m_k1
         // n0_n1_n2_n3 -> k0
         // m0_m1_m2 -> m
         // n4 -> k1
@@ -248,7 +248,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         const auto K      = q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2);
         const auto Gemm1N = v_grid_desc_n0_o_n1.GetLength(I1);

-        // This assumption redues implemention complexity by categorizing 6 separate GEMMs into 3
+        // This assumption reduces implemention complexity by categorizing 6 separate GEMMs into 3
         // types of GEMM operations, therefore some code body can be reused accordingly
         // P_MNK / dP_MNO Gemm (Gemm0 rcr)
         // Y_MON / dQ_MKN Gemm (Gemm1 rrr)
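Editorial note (not part of the commit): the three GEMM layouts named above cover the six GEMMs of the attention backward pass. Collecting the equations stated in comments elsewhere in this diff (S = Q*K^T, dV = P^T*dY, dP = dY*V^T, dS = P*(dP - Y_dot_dY), dQ = dS*K) and assuming the standard forward definitions P = softmax(S), Y = P*V, the math being implemented is roughly:

\begin{aligned}
S &= Q K^{\top}, & P &= \mathrm{softmax}(S), & Y &= P V \\
dV &= P^{\top}\, dY, & dP &= dY\, V^{\top}, & dS_{ij} &= P_{ij}\bigl(dP_{ij} - (Y \cdot dY)_i\bigr) \\
dQ &= dS\, K, & dK &= dS^{\top} Q &&
\end{aligned}

The dK = dS^T Q term is an inference from the "dV / dK Gemm (type 3 crr)" pairing later in the file, not an equation spelled out in this hunk.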
@@ -355,7 +355,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    // P / dP Gemm (type 1 rcr)
+    // S / dP Gemm (type 1 rcr)
     struct Gemm0
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -485,7 +485,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         static constexpr auto AThreadSliceLength_M  = Number<m0 * m1 * m2>{};
         static constexpr auto AThreadSliceLength_K1 = Number<n4>{};

-        static constexpr auto acc_thread_desc_k0_m_k1 =
+        static constexpr auto a_src_thread_desc_k0_m_k1 =
             GetA1SrcThreadDescriptor_AK0PerBlock_MPerBlock_AK1(
                 ASrcThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4{});
         static constexpr auto a_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
@@ -500,7 +500,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         using ABlockwiseCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
             FloatGemmAcc,
             DataType,
-            decltype(acc_thread_desc_k0_m_k1),
+            decltype(a_src_thread_desc_k0_m_k1),
             decltype(a_thread_desc_k0_m_k1),
             tensor_operation::element_wise::PassThrough,
             AThreadSliceLengths_K0_M_K1,
@@ -574,6 +574,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     };

     // dV / dK Gemm (type 3 crr)
+    // TODO ANT: refactor into Gemm2
     template <index_t Sum_M_ = MPerXdl * 2>
     struct VGradGemmTile_N_O_M_
     {
@@ -652,17 +653,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                              Number<4>{});
         }
     };
     using VGradGemmTile_N_O_M = VGradGemmTile_N_O_M_<>; // tune later

     template <index_t BlockSize_, index_t BlockSliceLength_M_, index_t BlockSliceLength_O_>
     struct YDotYGrad_M_O_
     {
-        static constexpr index_t SrcScalarPerVetor = 16 / sizeof(DataType);
+        static constexpr index_t SrcScalarPerVector = 16 / sizeof(DataType);
         static constexpr auto ThreadClusterLength_O =
-            Number<BlockSliceLength_O_ / SrcScalarPerVetor>{};
+            Number<BlockSliceLength_O_ / SrcScalarPerVector>{};
         static constexpr auto ThreadClusterLength_M = Number<BlockSize_ / ThreadClusterLength_O>{};
-        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVetor>{};
+        static constexpr auto ThreadSliceLength_O   = Number<SrcScalarPerVector>{};
         static constexpr auto ThreadSliceLength_M =
             Number<BlockSliceLength_M_ * ThreadClusterLength_O / BlockSize_>{};
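Editorial aside: SrcScalarPerVector is derived from a 16-byte load granularity, so the number of elements per vectorized access depends on the element type. A minimal standalone sketch of the same arithmetic; the half_t alias and the block/thread sizes below are illustrative assumptions, not values taken from this file:

#include <cstdint>
#include <cstdio>

// Stand-in for the kernel's DataType; a 2-byte type (like fp16) is assumed for illustration.
using half_t = std::uint16_t;

int main()
{
    // 16-byte (128-bit) accesses per thread, as in SrcScalarPerVector = 16 / sizeof(DataType).
    constexpr int SrcScalarPerVector = 16 / sizeof(half_t); // 8 elements for a 2-byte type

    // With assumed example values BlockSliceLength_O_ = 128 and BlockSize_ = 256,
    // the thread cluster shape follows the same arithmetic as YDotYGrad_M_O_:
    constexpr int BlockSliceLength_O    = 128;
    constexpr int BlockSize             = 256;
    constexpr int ThreadClusterLength_O = BlockSliceLength_O / SrcScalarPerVector; // 16
    constexpr int ThreadClusterLength_M = BlockSize / ThreadClusterLength_O;       // 16

    std::printf("%d elements/vector, %d x %d thread cluster\n",
                SrcScalarPerVector, ThreadClusterLength_M, ThreadClusterLength_O);
    return 0;
}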
@@ -683,8 +683,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     struct PGradGemmTile_M_N_O
     {
         // TODO ANT:
-        // Should have made all input tensors 2D and transform them into appropriate 3D form in
-        // kernel to make things more concise
+        // Make all input tensors 2D and transform them into appropriate 3D form in kernel to make
+        // things more concise - if we can get the compiler to behave
         template <typename YGradGridDesc_M0_O_M1_>
         __device__ static const auto
         MakeYGradGridDesc_O0_M_O1(const YGradGridDesc_M0_O_M1_& ygrad_grid_desc_m0_o_m1)
@@ -758,31 +758,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
         }

-        template <typename SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_>
-        __device__ static const auto MakeSGradThreadDesc_N0_M_N1(
-            const SGradThreadDesc_M0_N0_M1_N1_M2_N2_N3_N4_& sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4)
-        {
-            constexpr auto m0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
-            constexpr auto n0 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
-            constexpr auto m1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
-            constexpr auto n1 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
-            constexpr auto m2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
-            constexpr auto n2 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
-            constexpr auto n3 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
-            constexpr auto n4 = sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
-
-            constexpr auto sgrad_thread_desc_n0_m_n1 = transform_tensor_descriptor(
-                sgrad_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-                make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
-                           make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
-                           make_pass_through_transform(n4)),
-                make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
-
-            return sgrad_thread_desc_n0_m_n1;
-        }
-
         template <typename KGridDesc_K0_N_K1_>
         __device__ static const auto
         MakeKGridDesc_N0_K_N1(const KGridDesc_K0_N_K1_& k_grid_desc_k0_n_k1)
@@ -919,11 +894,26 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);

         //
-        // set up P / dP Gemm (type 1 rcr)
+        // set up S / dP Gemm (type 1 rcr)
         //

-        // A matrix blockwise copy
-        auto a_blockwise_copy =
+        // Gemm0: LDS allocation for A and B: be careful of alignment
+        auto gemm0_a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
+            Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+
+        auto gemm0_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
+            Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+
+        // Gemm0: gridwise GEMM pipeline
+        // Only supports LoopScheduler::Default
+        const auto gemm0_gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
+                                                                                NumGemmKPrefetchStage,
+                                                                                LoopScheduler::Default>();
+
+        // S: A matrix blockwise copy
+        auto s_gemm_tile_q_blockwise_copy =
             typename Gemm0::template ABlockwiseCopy<decltype(q_grid_desc_k0_m_k1)>(
                 q_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
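Editorial aside: both Gemm0 operand tiles live in the same dynamically allocated LDS arena (p_shared), carved up by byte offsets from SharedMemTrait. A simplified sketch of that pattern is below; a_block_space_offset / b_block_space_offset exist in the file, but the PartitionSketch struct, its sizes, and carve_lds are illustrative assumptions, not the kernel's actual SharedMemTrait:

#include <cstddef>

// Illustrative stand-in for SharedMemTrait: one LDS arena, two operand tiles at fixed offsets.
struct PartitionSketch
{
    // Assumed tile sizes in elements; the real values come from the block descriptors.
    static constexpr std::size_t a_block_space_size = 128 * 32; // MPerBlock x KPerBlock
    static constexpr std::size_t b_block_space_size = 128 * 32; // NPerBlock x KPerBlock

    static constexpr std::size_t a_block_space_offset = 0;
    static constexpr std::size_t b_block_space_offset = a_block_space_offset + a_block_space_size;
};

// Usage mirrors the kernel: offset a shared base pointer per operand before wrapping it
// in a dynamic buffer (hypothetical helper, for illustration only).
template <typename DataType>
__device__ void carve_lds(void* p_shared, DataType*& p_a, DataType*& p_b)
{
    p_a = static_cast<DataType*>(p_shared) + PartitionSketch::a_block_space_offset;
    p_b = static_cast<DataType*>(p_shared) + PartitionSketch::b_block_space_offset;
}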
@@ -932,8 +922,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

-        // B matrix blockwise copy
-        auto b_blockwise_copy =
+        // S: B matrix blockwise copy
+        auto s_gemm_tile_k_blockwise_copy =
             typename Gemm0::template BBlockwiseCopy<decltype(k_grid_desc_k0_n_k1)>(
                 k_grid_desc_k0_n_k1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension
@@ -942,76 +932,102 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

+        // S: blockwise gemm
         auto s_blockwise_gemm = typename Gemm0::BlockwiseGemm{}; // TransposeC

         auto s_slash_p_thread_buf = s_blockwise_gemm.GetCThreadBuffer();

-        // LDS allocation for A and B: be careful of alignment
-        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
-            Gemm0::a_block_desc_ak0_m_ak1.GetElementSpaceSize());
-
-        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
-            Gemm0::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        const auto a_block_reset_copy_step =
+        const auto s_gemm_tile_a_block_reset_copy_step =
             make_multi_index(-q_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
-        const auto b_block_reset_copy_step =
+        const auto s_gemm_tile_b_block_reset_copy_step =
             make_multi_index(-k_grid_desc_k0_n_k1.GetLength(I0), NPerBlock, 0);

-        // gridwise GEMM pipeline
-        // Only supports LoopScheduler::Default
-        const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
-                                                                          NumGemmKPrefetchStage,
-                                                                          LoopScheduler::Default>();
-
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (q_grid_desc_k0_m_k1.GetLength(I0) * q_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);

+        // dP: transform input and output tensor descriptors
+        const auto ygrad_grid_desc_o0_m_o1 =
+            PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
+        const auto v_grid_desc_o0_n_o1 =
+            PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
+
+        // dP: Gemm A position blockwise copy
+        auto pgrad_gemm_tile_ygrad_blockwise_copy =
+            typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
+                ygrad_grid_desc_o0_m_o1,
+                make_multi_index(0, m_block_data_idx_on_grid, 0),
+                tensor_operation::element_wise::PassThrough{},
+                Gemm0::a_block_desc_ak0_m_ak1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        // dP: Gemm B position blockwise copy
+        auto pgrad_gemm_tile_v_blockwise_copy =
+            typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
+                v_grid_desc_o0_n_o1,
+                make_multi_index(0, 0, 0), // will loop over GemmN dimension
+                tensor_operation::element_wise::PassThrough{},
+                Gemm0::b_block_desc_bk0_n_bk1,
+                make_multi_index(0, 0, 0),
+                tensor_operation::element_wise::PassThrough{});
+
+        // dP: blockwise gemm
+        // we need separate blockwise gemm object because we need separate thread buffer
+        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
+        auto pgrad_thread_buf     = pgrad_blockwise_gemm.GetCThreadBuffer();
+
+        const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
+            make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
+        const auto pgrad_gemm_tile_v_block_reset_copy_step =
+            make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
+
+        const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
+            (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
+            KPerBlock);
+
         //
         // set up Y / dQ Gemm (type 2 rrr)
         //
+        // Note: Y is pre-calculated in forward pass and loaded to backward pass kernel

         using Gemm1 =
             Gemm1<decltype(s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()),
                   decltype(s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4())>;

-        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
-        constexpr auto acc_thread_desc_k0_m_k1 = Gemm1::acc_thread_desc_k0_m_k1;
-
-        // A1 matrix in accumulator VGPR, dst of blockwise copy
-        constexpr auto a1_thread_desc_k0_m_k1 = Gemm1::a_thread_desc_k0_m_k1;
-
-        // B1 matrix in LDS memory, dst of blockwise copy
-        constexpr auto b1_block_desc_bk0_n_bk1 = Gemm1::b_block_desc_bk0_n_bk1;
-
-        // A1 matrix blockwise copy
-        auto a1_blockwise_copy =
+        // Gemm1: VGPR allocation for A and LDS allocation for B
+        auto gemm1_a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
+            Gemm1::a_thread_desc_k0_m_k1.GetElementSpaceSize());
+
+        auto gemm1_b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
+            Gemm1::b_block_desc_bk0_n_bk1.GetElementSpaceSize());
+
+        // dQ: transform input and output tensor descriptors
+        const auto k_grid_desc_n0_k_n1 =
+            QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
+        auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
+            QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
+                q_grid_desc_k0_m_k1);
+
+        // dQ: Gemm A matrix blockwise copy
+        auto qgrad_gemm_tile_sgrad_blockwise_copy =
             typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};

-        // B1 matrix blockwise copy
-        auto b1_blockwise_copy =
-            typename Gemm1::template BBlockwiseCopy<decltype(v_grid_desc_n0_o_n1)>(
-                v_grid_desc_n0_o_n1,
+        // dQ: Gemm B matrix blockwise copy
+        auto qgrad_gemm_tile_k_blockwise_copy =
+            typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
+                k_grid_desc_n0_k_n1,
                 make_multi_index(0, o_block_data_idx_on_grid, 0),
                 b1_element_op,
-                b1_block_desc_bk0_n_bk1,
+                Gemm1::b_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

-        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, DataType>(
-            a1_thread_desc_k0_m_k1.GetElementSpaceSize());
-
-        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
-            b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-
-        constexpr index_t Gemm1KPack = Gemm1::GemmKPack;
-
-        auto gemm1_blockwise_gemm =
+        // dQ: blockwise gemm
+        auto qgrad_blockwise_gemm =
             typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin

-        auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
+        auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();

         //
         // Blockwise softmax
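Editorial aside: after this rearrangement, all of the per-tile GEMMs (S, dV, dP, dQ) have their copies, LDS/VGPR buffers, and blockwise-gemm objects constructed once up front, and the outer loop only runs them. A schematic of one outer iteration, distilled from the comments in the hunks that follow; this is a sketch of the control flow, not an authoritative restatement of the kernel:

// Sketch: control flow of one work-tile after this commit (names as used in this file).
void backward_tile_outline(int num_gemm1_k_block_outer_loop)
{
    int gemm1_k_block_outer_index = 0;
    do
    {
        // S = Q * K^T              (gemm0_gridwise_gemm_pipeline + s_blockwise_gemm)
        // mask + softmax           (P overwrites S in s_slash_p_thread_buf)
        // dV = P^T * dY            (vgrad_blockwise_gemm, atomically accumulated to global)
        // dP = dY * V^T            (gemm0 pipeline reused via pgrad_* copies and buffers)
        // dS = P * (dP - Y_dot_dY) (elementwise, reusing pgrad_thread_buf as sgrad_thread_buf)
        // dQ = dS * K              (qgrad_blockwise_gemm into qgrad_thread_buf)
        // move slice windows: rewind K for the Q/K copies, rewind M for the dY copy,
        // step the dV destination window to the next N tile
    } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop);
}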
@@ -1391,19 +1407,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             y_thread_data_on_block_idx;

         // performs double duty for both y and ygrad
         auto yygrad_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
             DataType,
             DataType,
             YGridDescriptor_MBlock_MPerBlock_OBlock_OPerBlock,
             decltype(y_thread_desc_m0_m1_o0_o1),
             decltype(y_thread_desc_m0_m1_o0_o1.GetLengths()),
             Sequence<0, 1, 2, 3>,
             3,                                 // SrcVectorDim
-            YDotYGrad_M_O::SrcScalarPerVetor,  // SrcScalarPerVector
+            YDotYGrad_M_O::SrcScalarPerVector, // SrcScalarPerVector
             1,                                 // SrcScalarStrideInVector
             true /* ResetCoordAfterRun */,
             true /* InvalidElementAsNaN */>(
             y_grid_desc_mblock_mperblock_oblock_operblock, y_thread_data_on_grid_idx);

         auto y_thread_buf     = typename YDotYGrad_M_O::SrcBufType{};
         auto ygrad_thread_buf = typename YDotYGrad_M_O::SrcBufType{};
@@ -1435,80 +1451,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                            acc0_thread_origin[I2],   // mwave
                                            acc0_thread_origin[I4])}; // mperxdl

-        //
-        // set up dP Gemm (type 1 rcr)
-        //
-        // transform input and output tensor descriptors
-        const auto ygrad_grid_desc_o0_m_o1 =
-            PGradGemmTile_M_N_O::MakeYGradGridDesc_O0_M_O1(ygrad_grid_desc_m0_o_m1);
-        const auto v_grid_desc_o0_n_o1 =
-            PGradGemmTile_M_N_O::MakeVGridDesc_O0_N_O1(v_grid_desc_n0_o_n1);
-
-        // dP Gemm A position blockwise copy
-        auto pgrad_gemm_tile_ygrad_blockwise_copy =
-            typename Gemm0::template ABlockwiseCopy<decltype(ygrad_grid_desc_o0_m_o1)>(
-                ygrad_grid_desc_o0_m_o1,
-                make_multi_index(0, m_block_data_idx_on_grid, 0),
-                tensor_operation::element_wise::PassThrough{},
-                Gemm0::a_block_desc_ak0_m_ak1,
-                make_multi_index(0, 0, 0),
-                tensor_operation::element_wise::PassThrough{});
-
-        // dP Gemm B position blockwise copy
-        auto pgrad_gemm_tile_v_blockwise_copy =
-            typename Gemm0::template BBlockwiseCopy<decltype(v_grid_desc_o0_n_o1)>(
-                v_grid_desc_o0_n_o1,
-                make_multi_index(0, 0, 0), // will loop over GemmN dimension
-                tensor_operation::element_wise::PassThrough{},
-                Gemm0::b_block_desc_bk0_n_bk1,
-                make_multi_index(0, 0, 0),
-                tensor_operation::element_wise::PassThrough{});
-
-        auto pgrad_blockwise_gemm = typename Gemm0::BlockwiseGemm{};
-        auto pgrad_thread_buf     = pgrad_blockwise_gemm.GetCThreadBuffer();
-
-        const auto pgrad_gemm_tile_ygrad_block_reset_copy_step =
-            make_multi_index(-ygrad_grid_desc_o0_m_o1.GetLength(I0), 0, 0);
-        const auto pgrad_gemm_tile_v_block_reset_copy_step =
-            make_multi_index(-v_grid_desc_o0_n_o1.GetLength(I0), NPerBlock, 0);
-
-        const index_t num_o_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (ygrad_grid_desc_o0_m_o1.GetLength(I0) * ygrad_grid_desc_o0_m_o1.GetLength(I2)) /
-            KPerBlock);
-
         auto y_dot_ygrad_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatGemmAcc>(
             y_dot_ygrad_thread_desc_mblock_mrepeat_mwave_mperxdl.GetElementSpaceSize());

-        //
-        // set up dQ Gemm (type 2 rrr)
-        //
-        // transform input and output tensor descriptors
-        const auto k_grid_desc_n0_k_n1 =
-            QGradGemmTile_M_K_N::MakeKGridDesc_N0_K_N1(k_grid_desc_k0_n_k1);
-        auto qgrad_grid_desc_mblock_mperblock_kblock_kperblock =
-            QGradGemmTile_M_K_N::MakeQGradGridDesc_MBlock_MPerBlock_KBlock_KPerBlock(
-                q_grid_desc_k0_m_k1);
-
-        // dQ Gemm A matrix blockwise copy
-        auto qgrad_gemm_tile_sgrad_blockwise_copy =
-            typename Gemm1::ABlockwiseCopy{tensor_operation::element_wise::PassThrough{}};
-
-        // dQ Gemm B matrix blockwise copy
-        auto qgrad_gemm_tile_k_blockwise_copy =
-            typename Gemm1::template BBlockwiseCopy<decltype(k_grid_desc_n0_k_n1)>(
-                k_grid_desc_n0_k_n1,
-                make_multi_index(0, o_block_data_idx_on_grid, 0),
-                b1_element_op,
-                b1_block_desc_bk0_n_bk1,
-                make_multi_index(0, 0, 0),
-                tensor_operation::element_wise::PassThrough{});
-
-        auto qgrad_blockwise_gemm =
-            typename Gemm1::BlockwiseGemm{make_tuple(0, 0, 0, 0)}; // A_origin
-
-        auto qgrad_thread_buf = qgrad_blockwise_gemm.GetCThreadBuffer();
-
         //
         // calculate Y dot dY
         //
@@ -1586,22 +1531,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             {
                 continue;
             }

-            // P = Q * K^T
-            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(q_grid_desc_k0_m_k1,
-                                                                   Gemm0::a_block_desc_ak0_m_ak1,
-                                                                   a_blockwise_copy,
-                                                                   q_grid_buf,
-                                                                   a_block_buf,
-                                                                   Gemm0::a_block_slice_copy_step,
-                                                                   k_grid_desc_k0_n_k1,
-                                                                   Gemm0::b_block_desc_bk0_n_bk1,
-                                                                   b_blockwise_copy,
-                                                                   k_grid_buf,
-                                                                   b_block_buf,
-                                                                   Gemm0::b_block_slice_copy_step,
-                                                                   s_blockwise_gemm,
-                                                                   s_slash_p_thread_buf,
-                                                                   num_k_block_main_loop);
+            // S = Q * K^T
+            gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
+                q_grid_desc_k0_m_k1,
+                Gemm0::a_block_desc_ak0_m_ak1,
+                s_gemm_tile_q_blockwise_copy,
+                q_grid_buf,
+                gemm0_a_block_buf,
+                Gemm0::a_block_slice_copy_step,
+                k_grid_desc_k0_n_k1,
+                Gemm0::b_block_desc_bk0_n_bk1,
+                s_gemm_tile_k_blockwise_copy,
+                k_grid_buf,
+                gemm0_b_block_buf,
+                Gemm0::b_block_slice_copy_step,
+                s_blockwise_gemm,
+                s_slash_p_thread_buf,
+                num_k_block_main_loop);

             // do MNK padding or upper triangular masking
             if constexpr(MaskOutUpperTriangle || PadN)
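Editorial aside: MaskOutUpperTriangle guards the causal-attention case, where a query position must not attend to later key positions. A minimal scalar illustration of that predicate; the coordinate names below are hypothetical, and the kernel applies the equivalent test per element of the S tile before softmax:

#include <limits>

// Scalar sketch of upper-triangular (causal) masking applied to an S = Q*K^T score.
// m_global / n_global are hypothetical global row (query) and column (key) indices.
float mask_upper_triangle(float s, int m_global, int n_global)
{
    // Elements above the diagonal are forced to -inf so softmax assigns them zero weight.
    return (n_global > m_global) ? -std::numeric_limits<float>::infinity() : s;
}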
@@ -1679,7 +1625,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             vgrad_thread_buf.Clear();

-            // TODO ANT: single buffer prefetch pipeline
+            // TODO: tune pipeline
+            // dV = P^T * dY
             static_for<0, num_vgrad_gemm_loop, 1>{}([&](auto vgrad_gemm_loop_idx) { // gemm dV
                 // load VGrad Gemm B
                 ygrad_blockwise_copy.RunRead(ygrad_grid_desc_m0_o_m1, ygrad_grid_buf);
@@ -1714,7 +1661,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 vgrad_blockwise_gemm.Run(p_block_buf, ygrad_block_buf, vgrad_thread_buf);
             }); // end gemm dV
-
             // atomic_add dV
             vgrad_thread_copy_vgpr_to_global.Run(vgrad_thread_desc_n0_o0_n1_o1_n2_o2_o3_o4,
                                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
@@ -1723,26 +1669,27 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 vgrad_grid_buf);

             // gemm dP
-            // assume size K == size O so HasMainKBlockLoop is the same
             block_sync_lds();
-            gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
+            // dP = dY * V^T
+            // assume size K == size O so HasMainKBlockLoop is the same
+            gemm0_gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(
                 ygrad_grid_desc_o0_m_o1,
                 Gemm0::a_block_desc_ak0_m_ak1, // reuse
                 pgrad_gemm_tile_ygrad_blockwise_copy,
                 ygrad_grid_buf,
-                a_block_buf,                    // reuse
+                gemm0_a_block_buf,              // reuse
                 Gemm0::a_block_slice_copy_step, // reuse
                 v_grid_desc_o0_n_o1,
                 Gemm0::b_block_desc_bk0_n_bk1, // reuse
                 pgrad_gemm_tile_v_blockwise_copy,
                 v_grid_buf,
-                b_block_buf,                    // reuse
+                gemm0_b_block_buf,              // reuse
                 Gemm0::b_block_slice_copy_step, // reuse
                 pgrad_blockwise_gemm,
                 pgrad_thread_buf,
                 num_o_block_main_loop);

-            // calculate dS from dP
+            // dS = P * (dP - Y_dot_dY)
             auto& sgrad_thread_buf = pgrad_thread_buf;
             constexpr auto pgrad_thread_tile_iterator =
                 pgrad_blockwise_gemm.MakeCThreadTileIterator();
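Editorial aside: the renamed comment states the elementwise rule directly: dS is the softmax backward applied to dP, using the per-row reduction Y_dot_dY computed earlier in this kernel. A scalar sketch of what the tile-iterator loop computes per element (the function and parameter names are illustrative):

// Scalar sketch of the softmax backward step dS = P * (dP - Y_dot_dY).
// p is the softmax probability, dp the incoming gradient, y_dot_dy the per-row
// reduction of Y * dY for this element's row.
float softmax_backward_element(float p, float dp, float y_dot_dy)
{
    return p * (dp - y_dot_dy);
}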
@@ -1760,6 +1707,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             });

             // gemm dQ
+            // dQ = dS * K
             {
                 // TODO: explore using dynamic buffer for a1 thread buffer
                 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
@@ -1776,54 +1724,59 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 block_sync_lds(); // wait for previous LDS read
-                qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
+                qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
+                                                          gemm1_b_block_buf);

                 // main body
                 if constexpr(num_gemm1_k_block_inner_loop > 1)
                 {
                     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
-                        qgrad_gemm_tile_sgrad_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
+                        qgrad_gemm_tile_sgrad_blockwise_copy.Run(Gemm1::a_src_thread_desc_k0_m_k1,
                                                                  Gemm1::a_block_slice_copy_step * i,
                                                                  sgrad_thread_buf,
-                                                                 a1_thread_desc_k0_m_k1,
+                                                                 Gemm1::a_thread_desc_k0_m_k1,
                                                                  make_tuple(I0, I0, I0),
-                                                                 a1_thread_buf);
+                                                                 gemm1_a_thread_buf);

                         qgrad_gemm_tile_k_blockwise_copy.RunRead(k_grid_desc_n0_k_n1, k_grid_buf);

                         block_sync_lds();

-                        qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+                        qgrad_blockwise_gemm.Run(
+                            gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);

                         block_sync_lds();

                         qgrad_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
                             k_grid_desc_n0_k_n1, Gemm1::b_block_slice_copy_step);

-                        qgrad_gemm_tile_k_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1,
-                                                                  b1_block_buf);
+                        qgrad_gemm_tile_k_blockwise_copy.RunWrite(Gemm1::b_block_desc_bk0_n_bk1,
+                                                                  gemm1_b_block_buf);
                     });
                 }

                 // tail
                 {
                     qgrad_gemm_tile_sgrad_blockwise_copy.Run(
-                        acc_thread_desc_k0_m_k1,
+                        Gemm1::a_src_thread_desc_k0_m_k1,
                         Gemm1::a_block_slice_copy_step * Number<num_gemm1_k_block_inner_loop - 1>{},
                         sgrad_thread_buf,
-                        a1_thread_desc_k0_m_k1,
+                        Gemm1::a_thread_desc_k0_m_k1,
                         make_tuple(I0, I0, I0),
-                        a1_thread_buf);
+                        gemm1_a_thread_buf);

                     block_sync_lds();

-                    qgrad_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, qgrad_thread_buf);
+                    qgrad_blockwise_gemm.Run(
+                        gemm1_a_thread_buf, gemm1_b_block_buf, qgrad_thread_buf);
                 }
             } // end gemm dQ

             // move slice window
-            a_blockwise_copy.MoveSrcSliceWindow(q_grid_desc_k0_m_k1,
-                                                a_block_reset_copy_step); // rewind K
-            b_blockwise_copy.MoveSrcSliceWindow(k_grid_desc_k0_n_k1,
-                                                b_block_reset_copy_step); // rewind K and step N
+            s_gemm_tile_q_blockwise_copy.MoveSrcSliceWindow(
+                q_grid_desc_k0_m_k1,
+                s_gemm_tile_a_block_reset_copy_step); // rewind K
+            s_gemm_tile_k_blockwise_copy.MoveSrcSliceWindow(
+                k_grid_desc_k0_n_k1,
+                s_gemm_tile_b_block_reset_copy_step); // rewind K and step N
             ygrad_blockwise_copy.MoveSrcSliceWindow(ygrad_grid_desc_m0_o_m1,
                                                     ygrad_block_reset_copy_step); // rewind M
             vgrad_thread_copy_vgpr_to_global.MoveDstSliceWindow(
@@ -1836,7 +1789,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

-        // TODO ANT:
         // shuffle dQ and write
         {
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
@@ -1848,12 +1800,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             // TODO: hacky, fix it!
             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-                gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+                qgrad_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

             // TODO: hacky, fix it!
             // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
-                gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
+                qgrad_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

             constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
             constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
@@ -1893,7 +1845,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
             const auto c_thread_mtx_on_block =
-                gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+                qgrad_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];