gaoqiong / composable_kernel
Commit 4841d991, authored Dec 01, 2023 by Adam Osewski

Multiple changes to gridwise gemm.

parent ad0e4083
Showing 1 changed file with 151 additions and 84 deletions.

include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+151 -84)
...
...
@@ -441,7 +441,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         const index_t K,
         const index_t StrideA,
         const index_t StrideB,
-        const std::array<index_t, NumDTensor> StrideDs,
+        [[maybe_unused]] const std::array<index_t, NumDTensor> StrideDs,
         const index_t StrideE,
         const index_t KBatch)
     {
...
...
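The signature change swaps CK's `ignore = StrideDs;` sink (removed in the next hunk) for the C++17 `[[maybe_unused]]` attribute. A minimal standalone sketch of the two idioms for silencing an unused-parameter warning (names hypothetical, not from this file):

    #include <cstdint>

    // Old idiom: consume the parameter in the body with a sink;
    // CK's `ignore = stride;` plays the same role as the cast-to-void here.
    void f_old(std::int32_t stride) { (void)stride; }

    // New idiom: annotate the parameter itself, keeping the body clean.
    void f_new([[maybe_unused]] std::int32_t stride) {}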
@@ -449,10 +449,117 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
         const auto b_grid_desc_kbatch_bk0_n_bk1 =
             MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
         const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
-        ignore = StrideDs;

+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(M % MPerBlock == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            if(!(N % NPerBlock == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " "
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+        {
+            auto K_t = KBatch * KPerBlock;
+            if(!(K % K_t == 0))
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K value is not a multiple of KBatch * KPerBlock! K: " << K
+                          << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
+        {
+            if(K % ABlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K (" << K
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        else
+        {
+            if(M % ABlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg M (" << M
+                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+        {
+            if(N % BBlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg N (" << N
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
+        else
+        {
+            if(K % BBlockTransferSrcScalarPerVector != 0)
+            {
+#if DEBUG_LOG
+                std::cout << "Arg K (" << K
+                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+#endif // DEBUG_LOG
+                return false;
+            }
+        }
         // check gridwise gemm pipeline
         const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
...
...
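The checks above all follow one pattern: when the selected GemmSpecialization does not pad a given dimension, that dimension must tile exactly by its per-block size, and K must additionally split evenly across KBatch. A hedged host-side sketch of the same guard logic, with hypothetical tile sizes standing in for MPerBlock/NPerBlock/KPerBlock (plain C++, not CK API):

    #include <cstdio>

    constexpr int kMPerBlock = 256, kNPerBlock = 128, kKPerBlock = 32; // hypothetical

    // Mirrors the no-padding branches of CheckValidity: every dimension must
    // tile exactly, and K must split evenly over KBatch slices of KPerBlock.
    bool tiles_exactly(int M, int N, int K, int KBatch)
    {
        if(M % kMPerBlock != 0)
        { std::printf("M=%d is not a multiple of %d\n", M, kMPerBlock); return false; }
        if(N % kNPerBlock != 0)
        { std::printf("N=%d is not a multiple of %d\n", N, kNPerBlock); return false; }
        if(K % (KBatch * kKPerBlock) != 0)
        { std::printf("K=%d is not a multiple of KBatch*KPerBlock=%d\n", K, KBatch * kKPerBlock); return false; }
        return true;
    }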
@@ -461,6 +568,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
#if DEBUG_LOG
            std::cout << "The number of k loops (" << num_k_loop
                      << ") value is not supported by GridwiseGemm Pipeline."
                      << " K0Padded: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1)
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
#endif // DEBUG_LOG
            return false;
        }
...
...
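This guard delegates to GridwiseGemmPipe::IsSupported, which rejects main-loop counts the software pipeline cannot schedule. The actual CK rule is not shown in this diff; as a purely hypothetical illustration of the shape of such a predicate:

    #include <cstdint>

    // Hypothetical two-stage-prefetch pipeline: it cannot form a prologue and
    // epilogue with fewer than two K loops, so such problem shapes are rejected.
    struct TwoStagePipeSketch
    {
        static constexpr bool IsSupported(std::int32_t num_k_loop) { return num_k_loop >= 2; }
    };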
@@ -524,22 +637,18 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __host__ __device__ static auto
     MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
     {
-        const auto w_desc_grid_i1_mperb_nperb = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
-                    make_tuple(
-                        MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
-            {
-                return make_naive_tensor_descriptor(
-                    make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
-                    make_tuple(
-                        MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
-            }
-        }();
-        return w_desc_grid_i1_mperb_nperb;
+        if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
+                make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
+        }
+        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
+        {
+            return make_naive_tensor_descriptor(
+                make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
+                make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
+        }
     }

     // TODO: we should refactor out all those common Make... descriptors to sth like
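Read concretely, the row-major branch above lays the workspace out as (grid_size, 1, MPerBlock, NPerBlock) with strides (MPerBlock*NPerBlock, MPerBlock*NPerBlock, NPerBlock, 1): each workgroup owns one contiguous MPerBlock x NPerBlock tile, and tiles are packed back to back. A small sketch of that offset arithmetic (plain C++, hypothetical tile sizes, not CK API):

    #include <cassert>
    #include <cstdint>

    constexpr std::int64_t kMPerBlock = 128, kNPerBlock = 128; // hypothetical

    // Linear offset of element (g, 0, m, n) under the row-major strides above.
    std::int64_t workspace_offset(std::int64_t g, std::int64_t m, std::int64_t n)
    {
        return g * (kMPerBlock * kNPerBlock) // one full tile per workgroup
             + m * kNPerBlock                // row stride inside the tile
             + n;                            // unit stride along N
    }

    int main()
    {
        // Tiles are contiguous: tile g+1 begins right after tile g ends.
        assert(workspace_offset(1, 0, 0) ==
               workspace_offset(0, kMPerBlock - 1, kNPerBlock - 1) + 1);
        return 0;
    }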
...
...
@@ -700,73 +809,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// TODO Need to do CShuffle already here:
__device__
void
StorePartials
(
void
*
__restrict__
p_workspace
)
{
// M0 = grid_size
// N0 = 1
// M1 = MPerBlock
// N1 = NPerBlock
const
auto
workspace_grid_desc_m0_n0_m1_n1
=
MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock
(
get_grid_size
());
const
auto
w_grid_m0
=
workspace_grid_desc_m0_n0_m1_n1
.
GetLength
(
I0
);
const
auto
w_grid_n0
=
workspace_grid_desc_m0_n0_m1_n1
.
GetLength
(
I1
);
auto
p_workspace_grid
=
reinterpret_cast
<
AccDataType
*>
(
p_workspace
);
auto
w_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_workspace_grid
,
workspace_grid_desc_m0_n0_m1_n1
.
GetElementSpaceSize
());
const
auto
&
c_thread_buf
=
blockwise_gemm_
.
GetCThreadBuffer
();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
_tmp
=
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
BlockwiseGemmT
::
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
;
// M0 = grid_size -> MRepeats
// N0 = 1 -> NRepeats
const
auto
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
=
transform_tensor_descriptor
(
workspace_grid_desc_m0_n0_m1_n1
,
make_tuple
(
make_pass_through_transform
(
w_grid_m0
),
make_pass_through_transform
(
w_grid_n0
),
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
4
,
6
,
7
,
8
>
{},
Sequence
<
3
,
5
,
9
>
{}));
auto
p_workspace_grid
=
reinterpret_cast
<
AccDataType
*>
(
p_workspace
);
auto
w_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_workspace_grid
,
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
const
auto
workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
make_merge_transform
(
make_tuple
(
w_grid_m0
,
M0
)),
// MRepeats (grid)
make_merge_transform
(
make_tuple
(
w_grid_n0
,
N0
)),
// NRepeats (grid)
make_pass_through_transform
(
M1
),
// MWave
make_pass_through_transform
(
N1
),
// NWave
make_pass_through_transform
(
M2
),
// mfma_instr.num_groups_per_blk
make_pass_through_transform
(
M3
),
// mfma_instr.num_input_blks
make_pass_through_transform
(
M4
),
// mfma_instr.group_size
make_pass_through_transform
(
N2
)),
// mfma_instr.num_threads_per_blk
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{},
Sequence
<
8
>
{},
Sequence
<
9
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{},
Sequence
<
6
>
{},
Sequence
<
7
>
{}));
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I7
);
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
BlockwiseGemmT
::
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
...
...
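For orientation, StorePartials is half of a split-K reduction: each workgroup computes a partial GEMM over its K slice and writes its per-thread accumulators to the global workspace, and AccumulatePartials (below) folds the other slices' partials back into the thread buffer. A minimal single-threaded sketch of that flow under those assumptions (hypothetical names, not CK API):

    #include <cstddef>
    #include <vector>

    // Slice k writes its partial tile (tile_elems floats) into workspace slot k.
    void store_partials(std::vector<float>& workspace, const std::vector<float>& partial,
                        std::size_t k, std::size_t tile_elems)
    {
        for(std::size_t i = 0; i < tile_elems; ++i)
            workspace[k * tile_elems + i] = partial[i];
    }

    // The reducing owner already holds slot 0 in its accumulators ("already in
    // c_thread_buf"), so the loop starts at 1, mirroring AccumulatePartials.
    void accumulate_partials(std::vector<float>& acc, const std::vector<float>& workspace,
                             std::size_t reduce_count, std::size_t tile_elems)
    {
        for(std::size_t k = 1; k < reduce_count; ++k)
            for(std::size_t i = 0; i < tile_elems; ++i)
                acc[i] += workspace[k * tile_elems + i];
    }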
@@ -810,7 +873,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             1,     // DstScalarStrideInVector
             true>{ // DstResetCoordinateAfterRun
                 workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                make_multi_index(static_cast<index_t>(blockIdx.x) * MXdlPerWave,
+                make_multi_index(m_thread_data_on_block_idx[I0],
                                  n_thread_data_on_block_idx[I0],
                                  m_thread_data_on_block_idx[I1],
                                  n_thread_data_on_block_idx[I1],
...
...
@@ -827,14 +890,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                      w_grid_buf);
     }

-    __device__ void AccumulatePartials(void* __restrict__ p_workspace, index_t reduce_count)
+    __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)
     {
         auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();

         constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
             BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

         // using CThreadBufferT = ck::remove_reference_t<decltype(c_thread_buf)>;
         StaticBuffer<AddressSpaceEnum::Vgpr,
                      AccDataType,
                      c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
...
...
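The reduce_count parameter changes from index_t (signed) to uint32_t, matching the unsigned loop counter introduced in the next hunk; mixing the two would draw a signed/unsigned comparison warning. A tiny illustration of that likely motivation (an assumption; the commit does not state it):

    #include <cstdint>

    std::uint32_t sum_below(std::uint32_t reduce_count)
    {
        std::uint32_t s = 0;
        // With a matching unsigned counter there is no signed/unsigned mismatch
        // in `i_t < reduce_count`, so no -Wsign-compare warning.
        for(std::uint32_t i_t = 1; i_t < reduce_count; ++i_t)
            s += i_t;
        return s;
    }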
@@ -957,10 +1019,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         using Accumulation = ck::detail::
             AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;

         constexpr auto partial_acc_load_step =
             make_multi_index(MXdlPerWave, I0, I0, I0, I0, I0, I0, I0);

         // We do not need to read this workgroup's partial results since they're
         // already in c_thread_buf
-        for(int i_t = 1; i_t < reduce_count; ++i_t)
+        for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
         {
             acc_buf.Clear();
             acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
...
...
@@ -971,6 +1035,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
                 [&](auto i_vec) {
                     Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]);
                 });

             acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                         partial_acc_load_step);
         }
     }
...
...
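The accumulation loop above relies on ck::static_for to unroll over the thread buffer at compile time, applying reduce::Add element by element. A self-contained C++17 sketch of that unrolling technique (a simplification, not the CK implementation):

    #include <utility>

    // Calls f(std::integral_constant<int, I>{}) for I = 0..N-1; the fold
    // expression expands at compile time, so there is no runtime loop counter.
    template <typename F, int... Is>
    void unroll_impl(F&& f, std::integer_sequence<int, Is...>)
    {
        (f(std::integral_constant<int, Is>{}), ...);
    }

    template <int N, typename F>
    void static_for_sketch(F&& f)
    {
        unroll_impl(f, std::make_integer_sequence<int, N>{});
    }

    // Usage in the spirit of the loop above:
    //   static_for_sketch<Size>([&](auto i) { dst[i] += src[i]; });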