Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
45131629
Commit
45131629
authored
Nov 07, 2024
by
carlushuang
Browse files
update pipeline
parent
f09dc1f3
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
70 deletions
+150
-70
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
+1
-1
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
...ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
+3
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
.../ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
+76
-69
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+69
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
...e/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
+1
-0
No files found.
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
View file @
45131629
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
256
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
;
fused_moegemm_
<
t_
>
(
s
,
a
);
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
// clang-format on
...
...
example/ck_tile/15_fused_moe/instances/fused_moegemm_bf16_m32.cpp
View file @
45131629
...
@@ -11,4 +11,7 @@ template float fused_moegemm_<
...
@@ -11,4 +11,7 @@ template float fused_moegemm_<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
512
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
256
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
16
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm.hpp
View file @
45131629
...
@@ -51,6 +51,11 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -51,6 +51,11 @@ struct FusedMoeGemmPipeline_Flatmm
static
constexpr
index_t
kAlignmentD
=
Policy
::
template
GetAlignment_D
<
Problem
>();
static
constexpr
index_t
kAlignmentD
=
Policy
::
template
GetAlignment_D
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignment_O
<
Problem
>();
static
constexpr
index_t
kAlignmentO
=
Policy
::
template
GetAlignment_O
<
Problem
>();
static
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
static
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
static
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
static
constexpr
index_t
GST_O
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GST_O
);
static
constexpr
index_t
kBlockPerCu
=
[]()
{
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
return
Problem
::
kBlockPerCu
;
...
@@ -146,10 +151,14 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -146,10 +151,14 @@ struct FusedMoeGemmPipeline_Flatmm
auto
a_win
=
make_tile_window_linear
(
auto
a_win
=
make_tile_window_linear
(
a_window_
,
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
a_window_
,
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
auto
g_win
=
make_tile_window_linear
(
auto
g_win
=
g_window_
,
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>());
make_tile_window_linear
(
g_window_
,
auto
d_win
=
make_tile_window_linear
(
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>(),
d_window_
,
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
sequence
<
0
,
1
,
1
>
{});
auto
d_win
=
make_tile_window_linear
(
d_window_
,
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>(),
sequence
<
0
,
1
,
1
>
{});
auto
o_win
=
make_tile_window_linear
(
auto
o_win
=
make_tile_window_linear
(
o_window_
,
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
o_window_
,
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
...
@@ -239,8 +248,8 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -239,8 +248,8 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr
auto
issues_a
=
number
<
a_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_a
=
number
<
a_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_g
=
number
<
g_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_g
=
number
<
g_win
.
get_num_of_access
()
>
{};
constexpr
auto
issues_d
=
number
<
d_win
.
get_num_of_access
()
>
{};
//
constexpr auto issues_d = number<d_win.get_num_of_access()>{};
constexpr
auto
issues_o
=
number
<
o_win
.
get_num_of_access
()
>
{};
//
constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr
auto
issues_gemm0
=
constexpr
auto
issues_gemm0
=
number
<
BlockShape
::
Repeat_M0
*
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_K0
*
number
<
BlockShape
::
Repeat_M0
*
BlockShape
::
Repeat_N0
*
BlockShape
::
Repeat_K0
*
warp_gemm_0
.
get_num_of_access
()
>
{};
warp_gemm_0
.
get_num_of_access
()
>
{};
...
@@ -431,12 +440,7 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -431,12 +440,7 @@ struct FusedMoeGemmPipeline_Flatmm
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
index_t
SLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
SLD_A
);
constexpr
index_t
GLD_A
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_A
);
constexpr
index_t
GLD_B
=
static_cast
<
index_t
>
(
FusedMoeGemmPipelineSequencerEnum
::
GLD_B
);
constexpr
auto
c_sld_a_0
=
MAKE_SC
();
constexpr
auto
c_sld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
...
@@ -480,36 +484,33 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -480,36 +484,33 @@ struct FusedMoeGemmPipeline_Flatmm
};
};
auto
pipeline_gemm0_tail
=
[
&
]()
{
auto
pipeline_gemm0_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
index_t
mfma_per_gld_g
=
total_loops
/
issues_g
;
// BlockShape::Repeat_M0;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
// constexpr index_t mfma_per_gld_a = total_loops / issues_a;
static_assert
(
sr
.
size
()
==
total_loops
);
// constexpr index_t mfma_per_sld_a = total_loops / issues_sld_a;
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_g
==
0
)
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
{
gld_g
(
gs
[
I1
],
number
<
i_issue
/
mfma_per_gld_g
>
{});
move_g
();
}
// if constexpr (i_issue % mfma_per_gld_a == 0)
// gld_a(a_sst_win0, number<i_issue / mfma_per_gld_a>{});
// if constexpr(i_issue % mfma_per_sld_a == 0)
if
constexpr
(
slot
&
GLD_B
)
// {
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
// block_sync_load_raw(a_sst_win0.get_num_of_access());
// sld_a(as[I1], a_sld_win1, number<i_issue / mfma_per_sld_a>{});
// }
});
});
// if cycle_mfma>gld_a sync here
block_sync_load_raw
(
issues_g
);
block_sync_load_raw
(
issues_g
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
);
// compute buffer 1
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
,
TRUE
);
// last gemm has nop
constexpr
auto
last_nop
=
[
&
]()
{
if
constexpr
(
i_issue
==
(
total_loops
-
1
))
return
TRUE
;
else
return
FALSE
;
}();
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
,
last_nop
);
// last gemm has nop
});
});
};
};
...
@@ -527,73 +528,79 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -527,73 +528,79 @@ struct FusedMoeGemmPipeline_Flatmm
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto
pipeline_gemm1
=
[
&
]()
{
auto
pipeline_gemm1
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
mfma_per_gld_d
=
total_loops
/
issues_d
;
// BlockShape::Repeat_M0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>()
;
constexpr
index_t
mfma_per_atm_o
=
total_loops
/
issues_o
;
static_assert
(
sr
.
size
()
=
=
total_loops
)
;
// compute buffer 1
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
constexpr
auto
c_gst_o_1
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
{
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
if
constexpr
(
slot
&
GST_O
)
{
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
i_issue
/
mfma_per_atm_o
>
{});
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
}
});
});
move_d
();
// move_o();
// compute buffer
0
// compute buffer
1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
{
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
if
constexpr
(
slot
&
GST_O
)
{
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
atomic_add_o
(
out
,
number
<
i_issue
/
mfma_per_atm_o
>
{});
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_1
,
i_issue
)
>
{});
}
}
});
});
move_d
();
};
};
auto
pipeline_gemm1_head
=
[
&
]()
{
auto
pipeline_gemm1_head
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
mfma_per_gld_d
=
total_loops
/
issues_d
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
{
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
move_d
();
}
});
});
move_d
();
};
};
auto
pipeline_gemm1_tail
=
[
&
]()
{
auto
pipeline_gemm1_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
index_t
mfma_per_gld_d
=
total_loops
/
issues_d
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
constexpr
index_t
mfma_per_atm_o
=
total_loops
/
issues_o
;
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
// compute buffer 1
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
if
constexpr
(
i_issue
%
mfma_per_gld_d
==
0
)
{
gld_d
(
ds
[
I0
],
number
<
i_issue
/
mfma_per_gld_d
>
{});
move_d
();
}
if
constexpr
(
i_issue
%
mfma_per_atm_o
==
0
)
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GST_O
)
{
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
i_issue
/
mfma_per_atm_o
>
{});
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
}
});
});
{
{
...
@@ -620,7 +627,7 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -620,7 +627,7 @@ struct FusedMoeGemmPipeline_Flatmm
// we manually unroll double buffer inside hot loop
// we manually unroll double buffer inside hot loop
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
index_t
i_0
=
0
;
index_t
i_0
=
0
;
// (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while
(
i_0
++
<
iters_0
)
while
(
i_0
++
<
iters_0
)
{
{
pipeline_gemm0
();
pipeline_gemm0
();
...
@@ -630,7 +637,7 @@ struct FusedMoeGemmPipeline_Flatmm
...
@@ -630,7 +637,7 @@ struct FusedMoeGemmPipeline_Flatmm
pipeline_bridge
();
pipeline_bridge
();
const
index_t
iters_1
=
(
num_blocks_n1
-
2
)
/
2
;
const
index_t
iters_1
=
(
num_blocks_n1
-
2
)
/
2
;
index_t
i_1
=
0
;
index_t
i_1
=
0
;
// (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head
();
pipeline_gemm1_head
();
while
(
i_1
++
<
iters_1
)
while
(
i_1
++
<
iters_1
)
{
{
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
45131629
...
@@ -641,6 +641,75 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -641,6 +641,75 @@ struct FusedMoeGemmPipelineFlatmmPolicy
return
seq_all
;
return
seq_all
;
// clang-format on
// clang-format on
}
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
YDataType
,
ck_tile
::
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
DDataType
,
ck_tile
::
bf16_t
>
&&
S_
::
Warp_M0
==
32
&&
S_
::
Warp_N0
==
32
&&
S_
::
Warp_K0
==
16
&&
S_
::
Block_M0
==
32
&&
S_
::
Block_N0
==
256
&&
S_
::
Block_K0
==
128
&&
S_
::
Block_N1
==
128
)
{
// Total 32 instructions: 16x buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a, 8x ds_read_b128 sld_a; total 32 slots :)
// clang-format off
constexpr
auto
seq_all
=
// 0 1 2 3 4 5 6 7
sequence
<
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 0
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
GLD_B
,
GLD_A
,
// 1
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
// 2
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
,
GLD_B
,
SLD_A
>
{};
// 3
return
seq_all
;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
    // Builds the compile-time slot table for the second gemm: a sequence<...> holding one
    // entry per slot. Every entry is a bit mask taken from
    // FusedMoeGemmPipelineSequencerEnum (0 marks a slot with no extra operation); the
    // intent is to hide the marked memory operations underneath the mfma instructions.
    //
    // NOTE(review): only the two bf16 tile shapes below are handled. For any other Problem
    // no return statement is instantiated, so the deduced return type is void.
    using Shape = typename Problem::BlockShape;

    constexpr index_t ld_b = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    constexpr index_t st_o = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    // Requirements shared by both supported configurations.
    constexpr bool bf16_y_and_d =
        std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
        std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t>;
    constexpr bool warp_32x32x16 =
        Shape::Warp_M1 == 32 && Shape::Warp_N1 == 32 && Shape::Warp_K1 == 16;
    constexpr bool block_m32_k128_n128 =
        Shape::Block_M0 == 32 && Shape::Block_K0 == 128 && Shape::Block_N1 == 128;
    constexpr bool common_shape = bf16_y_and_d && warp_32x32x16 && block_m32_k128_n128;

    if constexpr(common_shape && Shape::Block_N0 == 512)
    {
        // 64 slots (8 rows x 8): 32x GLD_B, 8x GST_O, 24 empty.
        // clang-format off
        return sequence<
            // 0     1     2     3     4     5     6     7
            ld_b, st_o, ld_b, st_o, ld_b, st_o, ld_b, st_o,   // 0
            ld_b, st_o, ld_b, st_o, ld_b, st_o, ld_b, st_o,   // 1
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 2
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 3
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 4
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 5
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 6
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0       // 7
        >{};
        // clang-format on
    }
    else if constexpr(common_shape && Shape::Block_N0 == 256)
    {
        // 32 slots (4 rows x 8): 16x GLD_B, 8x GST_O, 8 empty.
        // clang-format off
        return sequence<
            // 0     1     2     3     4     5     6     7
            ld_b, st_o, ld_b, st_o, ld_b, st_o, ld_b, st_o,   // 0
            ld_b, st_o, ld_b, st_o, ld_b, st_o, ld_b, st_o,   // 1
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0,      // 2
            ld_b, 0,    ld_b, 0,    ld_b, 0,    ld_b, 0       // 3
        >{};
        // clang-format on
    }
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
View file @
45131629
...
@@ -43,5 +43,6 @@ enum class FusedMoeGemmPipelineSequencerEnum
...
@@ -43,5 +43,6 @@ enum class FusedMoeGemmPipelineSequencerEnum
GLD_B
=
1
<<
3
,
GLD_B
=
1
<<
3
,
SST_A
=
1
<<
4
,
// shared store a
SST_A
=
1
<<
4
,
// shared store a
SST_B
=
1
<<
5
,
SST_B
=
1
<<
5
,
GST_O
=
1
<<
6
,
// global store out
};
};
}
// namespace ck_tile
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment