Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
98e4c0ce
Commit
98e4c0ce
authored
Aug 03, 2022
by
Anthony Chang
Browse files
add BlockwiseGemmXdlops_v2 while exploring an unified approach
parent
eceea10a
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
437 additions
and
148 deletions
+437
-148
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+435
-129
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
...n/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
+2
-19
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
98e4c0ce
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
View file @
98e4c0ce
...
...
@@ -101,22 +101,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
BlockDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
(
const
BlockDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
BlockDesc_K0_MN_K1
{}.
GetLength
(
I0
);
constexpr
index_t
K1
=
BlockDesc_K0_MN_K1
{}.
GetLength
(
I2
);
return
transform_tensor_descriptor
(
BlockDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
...
...
@@ -453,8 +437,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// TODO ANT: to refactor: blockwise gemm output layout
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_v2
<
BlockSize
,
FloatAB
,
FloatGemmAcc
,
...
...
@@ -603,7 +586,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
math
::
lcm
(
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
group_size
,
B1K1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_
k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
auto
gemm1_blockwise_gemm
=
BlockwiseGemmXdlops_
v2
<
BlockSize
,
FloatAB
,
FloatGemmAcc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment