Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
08e44540
Commit
08e44540
authored
Jan 13, 2025
by
coderfeli
Browse files
debugs
parent
a75f162b
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
128 additions
and
48 deletions
+128
-48
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+4
-2
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
..._tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
+2
-2
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+5
-0
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
+3
-0
example/ck_tile/15_fused_moe/main.cpp
example/ck_tile/15_fused_moe/main.cpp
+36
-0
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
.../flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
+4
-4
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
+6
-6
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+36
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+32
-32
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
08e44540
...
@@ -25,7 +25,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -25,7 +25,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
...
@@ -37,7 +38,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -37,7 +38,8 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
0
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
// using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0>;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
// clang-format on
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_api_traits.hpp
View file @
08e44540
...
@@ -33,7 +33,7 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -33,7 +33,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
YSmoothScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
YSmoothScaleDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
TopkWeightDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
TopkWeightDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
using
IndexDataType
=
ck_tile
::
remove_cvref_t
<
typename
TypeConfig
::
IndexDataType
>
;
// S<32, 1024, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
// S<32, 1024
|512
, 128, 128>, S<1, 4, 1>, S<16, 16, 32>
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BT_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
0
>
{});
// block token
static
constexpr
ck_tile
::
index_t
BI_
=
static
constexpr
ck_tile
::
index_t
BI_
=
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
BlockTIle_
::
at
(
ck_tile
::
number
<
1
>
{});
// block intermediate
...
@@ -44,7 +44,7 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -44,7 +44,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
// S<1, 4, 1>
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
// S<1, 4, 1>
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
// 32, 128, 512
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
// 32, 128, 512
|256
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
/// S<1, 4, 1>
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
/// S<1, 4, 1>
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
// S<16, 16, 32>
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
08e44540
...
@@ -10,8 +10,13 @@
...
@@ -10,8 +10,13 @@
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// Explicit instantiation of fused_moegemm_ for the bf16 fmoe_ configuration
// with block tile S<32, 512, 128, 128>, warps-per-block S<1, 4, 1> and warp
// tile S<16, 16, 32>.
// NOTE(review): the trailing `0, 0` arguments presumably select the
// non-gate-only variant — confirm against the fmoe_ template parameters.
template float fused_moegemm_<
    fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t,
          float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>,
          0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/instances/fused_moegemm_fp16_m32.cpp
View file @
08e44540
...
@@ -13,5 +13,8 @@ template float fused_moegemm_<
...
@@ -13,5 +13,8 @@ template float fused_moegemm_<
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
1024
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
0
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// Explicit instantiation of fused_moegemm_ for the fp16 fmoe_ configuration
// with block tile S<32, 512, 128, 128>, warps-per-block S<1, 4, 1> and warp
// tile S<16, 16, 32>.
// NOTE(review): the trailing `0, 0` arguments presumably select the
// non-gate-only variant — confirm against the fmoe_ template parameters.
template float fused_moegemm_<
    fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t,
          float, float, float, float,
          S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>,
          0, 0>>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// clang-format on
example/ck_tile/15_fused_moe/main.cpp
View file @
08e44540
...
@@ -59,6 +59,42 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
...
@@ -59,6 +59,42 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
}
}
return
t
;
return
t
;
}
}
/// Host-side reshuffle of a rank-3 weight tensor for the gate/up GEMM.
///
/// The elements of `t` are copied, in iteration order, into a higher-rank
/// tensor whose extents are derived from the three lengths of `t`
/// (called d0, d1, d2 below). That tensor is then handed to
/// ck_tile::reference_permute with a fixed axis order. The layout is
/// selected by `mfma_dtype` and `mfma_type`:
///
///   "bf16"/"fp16", type 0: {d0, d1/32, 32, d2/16, 2, 8}            -> {0,1,3,4,2,5}
///   "bf16"/"fp16", type 1: {d0, 2, d1/512, 16, 16, d2/32, 4, 8}    -> {0,2,1,3,5,6,4,7}
///   "int8"/"fp8",  type 0: {d0, d1/32, 32, d2/32, 2, 16}           -> {0,1,3,4,2,5}
///   "int8"/"fp8",  type 1: {d0, d1/16, 16, d2/64, 4, 16}           -> {0,1,3,4,2,5}
///
/// Any other combination returns a copy of `t` as-is.
///
/// NOTE(review): the extents use integer division, so d1/d2 are presumably
/// expected to be multiples of the divisors above — confirm with callers.
///
/// @param t          source tensor; must have exactly three lengths (asserted).
/// @param mfma_dtype element-type tag: "bf16", "fp16", "int8" or "fp8".
/// @param mfma_type  layout variant selector (0 or 1); defaults to 0.
/// @return the permuted tensor, or a copy of `t` when no layout matches.
template <typename T>
auto shuffle_moe_weight_gateup(const ck_tile::HostTensor<T>& t,
                               std::string mfma_dtype,
                               int mfma_type = 0)
{
    assert(t.get_lengths().size() == 3);

    const int dim0 = t.get_lengths()[0];
    const int dim1 = t.get_lengths()[1];
    const int dim2 = t.get_lengths()[2];

    const bool is_16bit = (mfma_dtype == "bf16" || mfma_dtype == "fp16");
    const bool is_8bit  = (mfma_dtype == "int8" || mfma_dtype == "fp8");

    if(is_16bit && mfma_type == 0)
    {
        ck_tile::HostTensor<T> view6({dim0, dim1 / 32, 32, dim2 / 16, 2, 8});
        std::copy(t.begin(), t.end(), view6.begin());
        return ck_tile::reference_permute(view6, {0, 1, 3, 4, 2, 5});
    }

    if(is_16bit && mfma_type == 1)
    {
        // Rank-8 view: an extra factor of 2 is split out of the second length
        // before tiling it by 16x16 (2 * 16 * 16 == 512).
        ck_tile::HostTensor<T> view8({dim0, 2, dim1 / 512, 16, 16, dim2 / 32, 4, 8});
        std::copy(t.begin(), t.end(), view8.begin());
        return ck_tile::reference_permute(view8, {0, 2, 1, 3, 5, 6, 4, 7});
    }

    if(is_8bit && mfma_type == 0)
    {
        ck_tile::HostTensor<T> view6({dim0, dim1 / 32, 32, dim2 / 32, 2, 16});
        std::copy(t.begin(), t.end(), view6.begin());
        return ck_tile::reference_permute(view6, {0, 1, 3, 4, 2, 5});
    }

    if(is_8bit && mfma_type == 1)
    {
        ck_tile::HostTensor<T> view6({dim0, dim1 / 16, 16, dim2 / 64, 4, 16});
        std::copy(t.begin(), t.end(), view6.begin());
        return ck_tile::reference_permute(view6, {0, 1, 3, 4, 2, 5});
    }

    // No recognised dtype/type combination: hand back the input unchanged.
    return t;
}
template
<
typename
IndexType
>
template
<
typename
IndexType
>
void
topid_unique_gen
(
void
topid_unique_gen
(
...
...
include/ck_tile/ops/flatmm/block/flatmm_sn_32x128x256_1x4x1_16x16x32_itl.hpp
View file @
08e44540
...
@@ -427,10 +427,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
...
@@ -427,10 +427,10 @@ struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl : public FlatmmSn_32x128x512_
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b1
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
1
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b2
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
2
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b3
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
3
>
{}]
*
sizeof
(
BDataType
))),
[
v_os_b4
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
4
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[
v_os_b5
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
5
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[
v_os_b6
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
6
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[
v_os_b7
]
"v"
(
static_cast
<
index_t
>
(
cached_coords_b
[
number
<
7
>
{}]
*
sizeof
(
BDataType
))),
//
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_o
]
"s"
(
tile_stride_o_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
[
s_tile_os_b
]
"s"
(
tile_stride_b_bytes
),
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
View file @
08e44540
...
@@ -63,7 +63,7 @@ struct FusedMoeGemmShape
...
@@ -63,7 +63,7 @@ struct FusedMoeGemmShape
// S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
// S<32, 512, 128>, S<1, 4, 1>, S<16, 16, 32>
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_M0
=
BlockTile_0
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
//
512
static
constexpr
index_t
Block_N0
=
BlockTile_0
::
at
(
number
<
1
>
{});
//
256
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
// 128
static
constexpr
index_t
Block_K0
=
BlockTile_0
::
at
(
number
<
2
>
{});
// 128
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_M0
=
WarpPerBlock_0
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
number
<
1
>
{});
// 4
static
constexpr
index_t
WarpPerBlock_N0
=
WarpPerBlock_0
::
at
(
number
<
1
>
{});
// 4
...
@@ -73,18 +73,18 @@ struct FusedMoeGemmShape
...
@@ -73,18 +73,18 @@ struct FusedMoeGemmShape
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
// 32
static
constexpr
index_t
Warp_K0
=
WarpTile_0
::
at
(
number
<
2
>
{});
// 32
static
constexpr
index_t
ThreadPerBlock_M0
=
Warp_M0
*
WarpPerBlock_M0
;
static
constexpr
index_t
ThreadPerBlock_M0
=
Warp_M0
*
WarpPerBlock_M0
;
static
constexpr
index_t
ThreadPerBlock_N0
=
Warp_N0
*
WarpPerBlock_N0
;
static
constexpr
index_t
ThreadPerBlock_N0
=
Warp_N0
*
WarpPerBlock_N0
;
// 64
static
constexpr
index_t
ThreadPerBlock_K0
=
Warp_K0
*
WarpPerBlock_K0
;
static
constexpr
index_t
ThreadPerBlock_K0
=
Warp_K0
*
WarpPerBlock_K0
;
static_assert
(
Block_M0
%
ThreadPerBlock_M0
==
0
);
static_assert
(
Block_M0
%
ThreadPerBlock_M0
==
0
);
static_assert
(
Block_N0
%
ThreadPerBlock_N0
==
0
);
static_assert
(
Block_N0
%
ThreadPerBlock_N0
==
0
);
static_assert
(
Block_K0
%
ThreadPerBlock_K0
==
0
);
static_assert
(
Block_K0
%
ThreadPerBlock_K0
==
0
);
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
ThreadPerBlock_M0
;
// 2
static
constexpr
index_t
Repeat_M0
=
Block_M0
/
ThreadPerBlock_M0
;
// 2
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
ThreadPerBlock_N0
;
//
8
static
constexpr
index_t
Repeat_N0
=
Block_N0
/
ThreadPerBlock_N0
;
//
4
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
ThreadPerBlock_K0
;
// 4
static
constexpr
index_t
Repeat_K0
=
Block_K0
/
ThreadPerBlock_K0
;
// 4
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_M1
=
BlockTile_1
::
at
(
number
<
0
>
{});
//32
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
//128
static
constexpr
index_t
Block_N1
=
BlockTile_1
::
at
(
number
<
1
>
{});
//128
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
//
512
static
constexpr
index_t
Block_K1
=
BlockTile_1
::
at
(
number
<
2
>
{});
//
256
static
constexpr
index_t
WarpPerBlock_M1
=
WarpPerBlock_1
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_M1
=
WarpPerBlock_1
::
at
(
number
<
0
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_N1
=
WarpPerBlock_1
::
at
(
number
<
1
>
{});
// 4
static
constexpr
index_t
WarpPerBlock_N1
=
WarpPerBlock_1
::
at
(
number
<
1
>
{});
// 4
static
constexpr
index_t
WarpPerBlock_K1
=
WarpPerBlock_1
::
at
(
number
<
2
>
{});
// 1
static
constexpr
index_t
WarpPerBlock_K1
=
WarpPerBlock_1
::
at
(
number
<
2
>
{});
// 1
...
@@ -100,7 +100,7 @@ struct FusedMoeGemmShape
...
@@ -100,7 +100,7 @@ struct FusedMoeGemmShape
static_assert
(
Block_K1
%
ThreadPerBlock_K1
==
0
);
static_assert
(
Block_K1
%
ThreadPerBlock_K1
==
0
);
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
ThreadPerBlock_M1
;
// 2
static
constexpr
index_t
Repeat_M1
=
Block_M1
/
ThreadPerBlock_M1
;
// 2
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
ThreadPerBlock_N1
;
// 2
static
constexpr
index_t
Repeat_N1
=
Block_N1
/
ThreadPerBlock_N1
;
// 2
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
ThreadPerBlock_K1
;
//
16
static
constexpr
index_t
Repeat_K1
=
Block_K1
/
ThreadPerBlock_K1
;
//
8
static
constexpr
index_t
BlockSize
=
warpSize
*
NumWarps
;
static
constexpr
index_t
BlockSize
=
warpSize
*
NumWarps
;
...
@@ -118,7 +118,7 @@ struct FusedMoeGemmShape
...
@@ -118,7 +118,7 @@ struct FusedMoeGemmShape
static
constexpr
index_t
Block_Kr0
=
Block_K0
/
Warp_K0
;
static
constexpr
index_t
Block_Kr0
=
Block_K0
/
Warp_K0
;
static
constexpr
index_t
Block_W1
=
Warp_N1
*
Warp_K1
;
// 512
static
constexpr
index_t
Block_W1
=
Warp_N1
*
Warp_K1
;
// 512
static
constexpr
index_t
Block_Nr1
=
Block_N1
/
Warp_N1
;
// 8
static
constexpr
index_t
Block_Nr1
=
Block_N1
/
Warp_N1
;
// 8
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
//
16
static
constexpr
index_t
Block_Kr1
=
Block_K1
/
Warp_K1
;
//
8
static_assert
(
Block_W0
==
Block_W1
);
static_assert
(
Block_W0
==
Block_W1
);
// static_assert(Block_Nr0 == Block_Kr1);
// static_assert(Block_Nr0 == Block_Kr1);
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
08e44540
...
@@ -243,7 +243,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -243,7 +243,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_G
()
{
{
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
constexpr
auto
PermuteEnum
=
Problem
::
Traits
::
PermuteEnum
;
//
constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
constexpr
index_t
hidden_radio_0
=
Problem
::
Traits
::
IsGateOnly
?
1
:
2
;
using
S_
=
typename
Problem
::
BlockShape
;
using
S_
=
typename
Problem
::
BlockShape
;
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
if
constexpr
(
PermuteEnum
==
FusedMoeGemmWeightPermuteEnum
::
b_nr_kr_waveflatten
)
{
{
...
@@ -251,7 +251,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -251,7 +251,7 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// number<S_::Repeat_N0>{}.eee();
// number<S_::Repeat_N0>{}.eee();
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N0
,
return
MakeGlobalTileDistribution_Nr_Kr_W
<
S_
::
WarpPerBlock_N0
,
S_
::
WarpPerBlock_K0
,
S_
::
WarpPerBlock_K0
,
S_
::
Repeat_N0
,
///
hidden_radio_0,
S_
::
Repeat_N0
*
hidden_radio_0
,
S_
::
Repeat_K0
,
S_
::
Repeat_K0
,
get_warp_size
(),
get_warp_size
(),
GetAlignment_G
<
Problem
>
()
>
();
GetAlignment_G
<
Problem
>
()
>
();
...
@@ -804,6 +804,20 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -804,6 +804,20 @@ struct FusedMoeGemmPipelineFlatmmPolicy
{
{
return
Flatmm_32x512x128_1x4x1_16x16x32_FP16
{};
return
Flatmm_32x512x128_1x4x1_16x16x32_FP16
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
Flatmm_32x256x128_1x4x1_16x16x32_BF16
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
ADataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
GDataType
,
ck_tile
::
fp16_t
>
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
)
{
return
Flatmm_32x256x128_1x4x1_16x16x32_FP16
{};
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -851,6 +865,26 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -851,6 +865,26 @@ struct FusedMoeGemmPipelineFlatmmPolicy
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl
{};
return
FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl
{};
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
256
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
true
)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return
FlatmmSn_32x128x256_1x4x1_16x16x32_BF16_itl
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
fp16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
TopkWeightDataType
,
float
>
&&
S_
::
Block_M1
==
32
&&
S_
::
Block_N1
==
128
&&
S_
::
Block_K1
==
256
&&
S_
::
Warp_M0
==
16
&&
S_
::
Warp_N0
==
16
&&
S_
::
Warp_K0
==
32
&&
T_
::
PipeInterleave
==
true
)
{
// return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
return
FlatmmSn_32x128x256_1x4x1_16x16x32_FP16_itl
{};
}
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
08e44540
...
@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -73,7 +73,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_0
=
Policy
::
template
GetUK_0
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_1
=
Policy
::
template
GetUK_1
<
Problem
>().
GetSmemSize
();
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
(
IsGateOnly
?
1
:
2
)
;
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
;
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
return
max
(
smem_0
,
max
(
smem_1
,
smem_bridge
));
}
}
...
@@ -168,7 +168,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -168,7 +168,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t
intermediate_tile_id
)
index_t
intermediate_tile_id
)
{
{
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
/
hidden_radio_0
;
ck_tile
::
index_t
shared_intermediate_size_0
=
kargs
.
intermediate_size
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
ck_tile
::
index_t
shared_intermediate_size_1
=
kargs
.
intermediate_size
/
hidden_radio_0
;
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
index_t
nr_0
=
shared_intermediate_size_0
/
BlockShape
::
Warp_N0
;
// divide N in W
...
@@ -178,13 +178,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -178,13 +178,13 @@ struct FusedMoeGemmPipeline_FlatmmUk
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
*
hidden_radio_0
;
index_t
expert_stride_0
=
shared_intermediate_size_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
shared_intermediate_size_1
*
kargs
.
hidden_size
;
// nr*kr*w
// nr*kr*w
index_t
interm_idx_nr0
=
__builtin_amdgcn_readfirstlane
(
index_t
interm_idx_nr0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
// intermediate_tile_id * Block_N / (N in W)
BlockShape
::
Block_Nr0
*
hidden_radio_0
);
// intermediate_tile_id * Block_N / (N in W)
index_t
interm_idx_kr1
=
__builtin_amdgcn_readfirstlane
(
index_t
interm_idx_kr1
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
intermediate_tile_id
*
...
@@ -218,7 +218,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -218,7 +218,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
g_window_
=
make_tile_window_linear_raw
(
auto
g_window_
=
make_tile_window_linear_raw
(
g_view_
,
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
make_tuple
(
number
<
BlockShape
::
Block_Nr0
*
hidden_radio_0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
},
{
0
,
0
,
0
},
...
@@ -351,32 +351,32 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -351,32 +351,32 @@ struct FusedMoeGemmPipeline_FlatmmUk
block_sync_lds
();
block_sync_lds
();
// up
// up
if
(
!
IsGateOnly
)
//
if(!IsGateOnly)
{
//
{
// up ptr. add half expert_stride_0 as offset.
//
// up ptr. add half expert_stride_0 as offset.
auto
u_win
=
gu_win_gen
(
shared_intermediate_size_0
*
kargs
.
hidden_size
);
//
auto u_win = gu_win_gen(shared_intermediate_size_0 * kargs.hidden_size);
auto
u_res
=
u_win
.
get_bottom_tensor_view
().
get_buffer_view
().
cached_buf_res_
;
//
auto u_res = u_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto
u_coords
=
//
auto u_coords =
generate_tuple
([
&
](
auto
i
)
{
return
u_win
.
cached_coords_
[
i
].
get_offset
();
},
//
generate_tuple([&](auto i) { return u_win.cached_coords_[i].get_offset(); },
number
<
decltype
(
u_win
)
::
NumAccess_NonLinear
>
{});
//
number<decltype(u_win)::NumAccess_NonLinear>{});
// reuse UK0
//
// reuse UK0
auto
uk_0_u
=
Policy
::
template
GetUK_0
<
Problem
>();
//
auto uk_0_u = Policy::template GetUK_0<Problem>();
auto
acc_0_u
=
uk_0_u
(
a_res
,
//
auto acc_0_u = uk_0_u(a_res,
a_coords
,
//
a_coords,
u_res
,
//
u_res,
u_coords
,
//
u_coords,
smem
,
//
smem,
kargs
.
hidden_size
,
//
kargs.hidden_size,
BlockShape
::
Block_K0
,
// tile offset for B matrix each unroll
//
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape
::
Block_Kr0
*
//
BlockShape::Block_Kr0 *
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
//
BlockShape::Block_W0); // tile offset for B matrix each unroll
// elementwise mul gate*up.
//
// elementwise mul gate*up.
sweep_tile
(
//
sweep_tile(
y_pre
,
//
y_pre,
[
&
](
auto
idx0
)
{
y_pre
(
idx0
)
=
y_pre
(
idx0
)
*
acc_0_u
(
idx0
);
},
//
[&](auto idx0) { y_pre(idx0) = y_pre(idx0) * acc_0_u(idx0); },
sequence
<
1
,
1
>
{});
//
sequence<1, 1>{});
block_sync_lds
();
//
block_sync_lds();
}
//
}
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
store_tile
(
bridge_sst_win
,
cast_tile
<
YDataType
>
(
y_pre
));
block_sync_lds
();
block_sync_lds
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment