Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a75f162b
Commit
a75f162b
authored
Jan 12, 2025
by
coderfeli
Browse files
add files
parent
59c05300
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
1354 additions
and
39 deletions
+1354
-39
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+7
-7
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+8
-1
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
+511
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
+794
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
+28
-27
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+6
-4
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
View file @
a75f162b
...
@@ -33,20 +33,20 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -33,20 +33,20 @@ struct fmoe_ // traits, ugly name, only used for internal
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
// S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BI_
=
static
constexpr
ck_tile
::
index_t
BI_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BH_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
2
>
{});
// block hidden
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
static
constexpr
ck_tile
::
index_t
BD_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
3
>
{});
// block down
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
/
(
GateOnly_
?
1
:
2
),
BH_
>
;
using
BlockTile_0
=
ck_tile
::
sequence
<
BT_
,
BI_
/
(
GateOnly_
?
1
:
2
),
BH_
>
;
//32, 512, 128
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
// S<1, 4, 1>
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
// 32, 128, 512
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
/// S<1, 4, 1>
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
static
constexpr
ck_tile
::
index_t
FusedQuant
=
FusedQuant_
;
...
...
example/ck_tile/15_fused_moe/main.cpp
View file @
a75f162b
...
@@ -285,7 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -285,7 +285,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
else
{
{
topid_unique_gen
<
IndexDataType
>
(
topk_ids_host
.
mData
,
tokens
,
topk
,
experts
,
11913
);
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
topk_ids_host
.
mData
.
size
());
i
++
)
{
topk_ids_host
.
mData
[
i
]
=
i
%
4
;
}
// topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
}
// leave it here for future debug purpose
// leave it here for future debug purpose
...
@@ -442,6 +445,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -442,6 +445,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
topk
,
topk
,
gate_only
);
gate_only
);
sorted_token_ids_host
.
savetxt
(
"sorted_token_ids_host.txt"
,
"int"
);
sorted_expert_ids_host
.
savetxt
(
"sorted_expert_ids_host.txt"
,
"int"
);
num_sorted_tiles_host
.
savetxt
(
"num_sorted_tiles_host.txt"
,
"int"
);
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
// o_dev.savetxt("gpu-out.txt", "float");
// o_dev.savetxt("gpu-out.txt", "float");
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
auto
[
rtol
,
atol
]
=
get_elimit
<
ADataType
>
();
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
0 → 100644
View file @
a75f162b
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x256_1x4x1_16x16x16_itl.inc
0 → 100644
View file @
a75f162b
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
View file @
a75f162b
...
@@ -61,15 +61,16 @@ struct FusedMoeGemmShape
...
@@ -61,15 +61,16 @@ struct FusedMoeGemmShape
// TODO: we don't support half warps aound to 1 warp here
// TODO: we don't support half warps aound to 1 warp here
static_assert
(
NumWarps
==
reduce_on_sequence
(
WarpPerBlock_1
{},
multiplies
{},
number
<
1
>
{}));
static_assert
(
NumWarps
==
reduce_on_sequence
(
WarpPerBlock_1
{},
multiplies
{},
number
<
1
>
{}));
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
// S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
//512
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
// 128
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_K0
=
WarpPerBlock_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
number
<
1
>
{});
// 4
static
constexpr
index_t
Warp_M0
=
WarpTile_0
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_K0
=
WarpPerBlock_0
::
at
(
number
<
2
>
{});
// 1
static
constexpr
index_t
Warp_N0
=
WarpTile_0
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_M0
=
WarpTile_0
::
at
(
number
<
0
>
{});
// 16
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_N0
=
WarpTile_0
::
at
(
number
<
1
>
{});
// 16
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
// 32
static
constexpr
index_t
ThreadPerBlock_M0
=
Warp_M0
*
WarpPerBlock_M0
;
static
constexpr
index_t
ThreadPerBlock_M0
=
Warp_M0
*
WarpPerBlock_M0
;
static
constexpr
index_t
ThreadPerBlock_N0
=
Warp_N0
*
WarpPerBlock_N0
;
static
constexpr
index_t
ThreadPerBlock_N0
=
Warp_N0
*
WarpPerBlock_N0
;
...
@@ -77,19 +78,19 @@ struct FusedMoeGemmShape
...
@@ -77,19 +78,19 @@ struct FusedMoeGemmShape
static_assert
(
Block_M0
%
ThreadPerBlock_M0
==
0
);
static_assert
(
Block_M0
%
ThreadPerBlock_M0
==
0
);
static_assert
(
Block_N0
%
ThreadPerBlock_N0
==
0
);
static_assert
(
Block_N0
%
ThreadPerBlock_N0
==
0
);
static_assert
(
Block_K0
%
ThreadPerBlock_K0
==
0
);
static_assert
(
Block_K0
%
ThreadPerBlock_K0
==
0
);
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
ThreadPerBlock_M0
;
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
ThreadPerBlock_M0
;
// 2
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
ThreadPerBlock_N0
;
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
ThreadPerBlock_N0
;
// 8
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
ThreadPerBlock_K0
;
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
ThreadPerBlock_K0
;
// 4
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
//128
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
//512
static
constexpr
index_t
WarpPerBlock_M1
=
WarpPerBlock_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_M1
=
WarpPerBlock_1
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_N1
=
WarpPerBlock_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
WarpPerBlock_N1
=
WarpPerBlock_1
::
at
(
number
<
1
>
{});
// 4
static
constexpr
index_t
WarpPerBlock_K1
=
WarpPerBlock_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
WarpPerBlock_K1
=
WarpPerBlock_1
::
at
(
number
<
2
>
{});
// 1
static
constexpr
index_t
Warp_M1
=
WarpTile_1
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_M1
=
WarpTile_1
::
at
(
number
<
0
>
{});
// 16
static
constexpr
index_t
Warp_N1
=
WarpTile_1
::
at
(
number
<
1
>
{});
static
constexpr
index_t
Warp_N1
=
WarpTile_1
::
at
(
number
<
1
>
{});
// 16
static
constexpr
index_t
Warp_K1
=
WarpTile_1
::
at
(
number
<
2
>
{});
static
constexpr
index_t
Warp_K1
=
WarpTile_1
::
at
(
number
<
2
>
{});
// 32
static
constexpr
index_t
ThreadPerBlock_M1
=
Warp_M1
*
WarpPerBlock_M1
;
static
constexpr
index_t
ThreadPerBlock_M1
=
Warp_M1
*
WarpPerBlock_M1
;
static
constexpr
index_t
ThreadPerBlock_N1
=
Warp_N1
*
WarpPerBlock_N1
;
static
constexpr
index_t
ThreadPerBlock_N1
=
Warp_N1
*
WarpPerBlock_N1
;
...
@@ -97,9 +98,9 @@ struct FusedMoeGemmShape
...
@@ -97,9 +98,9 @@ struct FusedMoeGemmShape
static_assert
(
Block_M1
%
ThreadPerBlock_M1
==
0
);
static_assert
(
Block_M1
%
ThreadPerBlock_M1
==
0
);
static_assert
(
Block_N1
%
ThreadPerBlock_N1
==
0
);
static_assert
(
Block_N1
%
ThreadPerBlock_N1
==
0
);
static_assert
(
Block_K1
%
ThreadPerBlock_K1
==
0
);
static_assert
(
Block_K1
%
ThreadPerBlock_K1
==
0
);
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
ThreadPerBlock_M1
;
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
ThreadPerBlock_M1
;
// 2
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
ThreadPerBlock_N1
;
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
ThreadPerBlock_N1
;
// 2
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
ThreadPerBlock_K1
;
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
ThreadPerBlock_K1
;
// 16
static
constexpr
index_t
BlockSize
=
warpSize
*
NumWarps
;
static
constexpr
index_t
BlockSize
=
warpSize
*
NumWarps
;
...
@@ -115,9 +116,9 @@ struct FusedMoeGemmShape
...
@@ -115,9 +116,9 @@ struct FusedMoeGemmShape
static
constexpr
index_t
Block_W0
=
Warp_N0
*
Warp_K0
;
static
constexpr
index_t
Block_W0
=
Warp_N0
*
Warp_K0
;
static
constexpr
index_t
Block_Nr0
=
Block_N0
/
Warp_N0
;
static
constexpr
index_t
Block_Nr0
=
Block_N0
/
Warp_N0
;
static
constexpr
index_t
Block_Kr0
=
Block_K0
/
Warp_K0
;
static
constexpr
index_t
Block_Kr0
=
Block_K0
/
Warp_K0
;
static
constexpr
index_t
Block_W1
=
Warp_N1
*
Warp_K1
;
static
constexpr
index_t
Block_W1
=
Warp_N1
*
Warp_K1
;
// 512
static
constexpr
index_t
Block_Nr1
=
Block_N1
/
Warp_N1
;
static
constexpr
index_t
Block_Nr1
=
Block_N1
/
Warp_N1
;
// 8
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
// 16
static_assert
(
Block_W0
==
Block_W1
);
static_assert
(
Block_W0
==
Block_W1
);
// static_assert(Block_Nr0 == Block_Kr1);
// static_assert(Block_Nr0 == Block_Kr1);
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
a75f162b
...
@@ -199,6 +199,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -199,6 +199,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
threadIdx
.
x
%
(
BlockShape
::
Block_K0
/
kAlignmentA
)
*
kAlignmentA
;
},
},
number
<
row_ids_a
.
size
()
>
{});
number
<
row_ids_a
.
size
()
>
{});
if
(
row_ids_a
[
0
]
>=
kargs
.
num_tokens
)
return
;
auto
a_res
=
auto
a_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
),
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
kargs
.
num_tokens
*
kargs
.
stride_token
*
sizeof
(
ADataType
));
...
@@ -238,7 +240,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -238,7 +240,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
// n/16, k/32, 512
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
kAlignmentD
>
{},
number
<
kAlignmentD
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -264,13 +266,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -264,13 +266,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
d_coords
=
[
&
]()
{
auto
d_coords
=
[
&
]()
{
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nr_
=
2
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Nw_
=
4
;
constexpr
index_t
Kr0_
=
4
;
constexpr
index_t
Kr0_
=
BlockShape
::
Block_Kr1
/
Kr1_
;
//
4
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kr1_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Kl_
=
4
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Nl_
=
16
;
constexpr
index_t
Kv_
=
8
;
constexpr
index_t
Kv_
=
8
;
constexpr
index_t
W_
=
Kl_
*
Nl_
*
Kv_
;
constexpr
index_t
W_
=
Kl_
*
Nl_
*
Kv_
;
// 512
constexpr
index_t
num_offsets_
=
Nr_
*
Kr0_
;
constexpr
index_t
num_offsets_
=
Nr_
*
Kr0_
;
// 8
index_t
base_os_
=
(
threadIdx
.
x
%
64
)
*
Kv_
+
(
threadIdx
.
x
/
64
)
*
index_t
base_os_
=
(
threadIdx
.
x
%
64
)
*
Kv_
+
(
threadIdx
.
x
/
64
)
*
shared_intermediate_size_1
*
shared_intermediate_size_1
*
Nl_
;
// Kr0_ * Kr1_ * W_;
Nl_
;
// Kr0_ * Kr1_ * W_;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment