Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
84755f74
"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "3986e2d2cfadd43d9bb5fbac5ef711f902c06831"
Commit
84755f74
authored
Nov 16, 2024
by
“letaoqin”
Browse files
format
parent
eab497e8
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
486 additions
and
98 deletions
+486
-98
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp
...used_moe_general/instances/fused_moegemm_api_internal.hpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+1
-1
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+9
-5
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+1
-1
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+23
-21
include/ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp
.../pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp
+20
-24
include/ck_tile/ops/fused_moe.hpp
include/ck_tile/ops/fused_moe.hpp
+1
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+382
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
+6
-3
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
...sed_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
+25
-24
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
+15
-16
No files found.
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
84755f74
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
}
// clang-format on
// clang-format on
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_internal.hpp
View file @
84755f74
...
@@ -40,7 +40,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
...
@@ -40,7 +40,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmGl
<
f_problem
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_FlatmmGl
<
f_problem
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemm
Gl
Kernel
<
f_partitioner
,
f_pipeline
,
void
>
;
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
84755f74
...
@@ -48,7 +48,7 @@ struct fmoe_ // traits, ugly name, only used for internal
...
@@ -48,7 +48,7 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
;
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
84755f74
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
// clang-format off
// clang-format off
template
float
fused_moegemm_
<
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
// clang-format on
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
84755f74
...
@@ -432,14 +432,18 @@ struct tile_window_linear
...
@@ -432,14 +432,18 @@ struct tile_window_linear
CK_TILE_DEVICE
static
constexpr
index_t
get_bottom_linear_offset
(
number
<
i_access
>
)
CK_TILE_DEVICE
static
constexpr
index_t
get_bottom_linear_offset
(
number
<
i_access
>
)
{
{
constexpr
auto
linear_coord
=
get_bottom_linear_coordinate
(
number
<
i_access
>
{});
constexpr
auto
linear_coord
=
get_bottom_linear_coordinate
(
number
<
i_access
>
{});
constexpr
auto
is_pure_linear_tensor
=
reduce_on_sequence
(
LinearBottomDims
{},
multiplies
{},
number
<
1
>
{});
constexpr
auto
is_pure_linear_tensor
=
if
constexpr
(
is_pure_linear_tensor
)
{
reduce_on_sequence
(
LinearBottomDims
{},
multiplies
{},
number
<
1
>
{});
if
constexpr
(
is_pure_linear_tensor
)
{
// this case usually is a LDS window, everything is build time know.
// this case usually is a LDS window, everything is build time know.
// we directly use BottomTensorView to compute the offset, in case there is any padding
// we directly use BottomTensorView to compute the offset, in case there is any padding
auto
bottom_tensor_coord
=
make_tensor_coordinate
(
auto
bottom_tensor_coord
=
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
make_tensor_coordinate
(
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
return
bottom_tensor_coord
.
get_offset
();
return
bottom_tensor_coord
.
get_offset
();
}
else
{
}
else
{
// this case usually is a global window, where last dim can be linear
// this case usually is a global window, where last dim can be linear
// we hack here, that use the original TileDstr to compute the linear offset
// we hack here, that use the original TileDstr to compute the linear offset
// ... hoping that there is no extra padding between other dims, which make sense
// ... hoping that there is no extra padding between other dims, which make sense
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
84755f74
...
@@ -135,7 +135,7 @@ void reference_fused_moe(
...
@@ -135,7 +135,7 @@ void reference_fused_moe(
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
{
{
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
//printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
//
printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
}
}
else
else
...
...
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
View file @
84755f74
...
@@ -620,8 +620,8 @@ struct FastGeluAsm
...
@@ -620,8 +620,8 @@ struct FastGeluAsm
CK_TILE_HOST
void
operator
()
<
fp32x2_t
,
fp32x2_t
>
(
fp32x2_t
&
y
,
const
fp32x2_t
&
x
)
const
CK_TILE_HOST
void
operator
()
<
fp32x2_t
,
fp32x2_t
>
(
fp32x2_t
&
y
,
const
fp32x2_t
&
x
)
const
{
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c1
=
-
2.0
*
0.035677
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
c2
=
-
2.0
*
0.797885
f
;
const
float
u0
=
x
.
x
*
(
c1
*
x
.
x
*
x
.
x
+
c2
);
const
float
u0
=
x
.
x
*
(
c1
*
x
.
x
*
x
.
x
+
c2
);
const
float
emu0
=
exp
(
u0
);
const
float
emu0
=
exp
(
u0
);
y
.
x
=
x
.
x
/
(
1.
f
+
emu0
);
y
.
x
=
x
.
x
/
(
1.
f
+
emu0
);
...
@@ -641,25 +641,27 @@ struct FastGeluAsm
...
@@ -641,25 +641,27 @@ struct FastGeluAsm
float
tmp0
,
tmp1
;
float
tmp0
,
tmp1
;
float
y0
,
y1
;
float
y0
,
y1
;
asm
volatile
(
"v_mul_f32 %[v_tmp0], %[v_x0], %[v_x0] ; x*x
\n
"
asm
volatile
(
"v_mul_f32 %[v_tmp1], %[v_x1], %[v_x1] ; x*x
\n
"
"v_mul_f32 %[v_tmp0], %[v_x0], %[v_x0] ; x*x
\n
"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_mul_f32 %[v_tmp1], %[v_x1], %[v_x1] ; x*x
\n
"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_x0] ; x*(c1*x*x+c2)
\n
"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_x1] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_x0] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_x1] ; x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)
\n
"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f
\n
"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))
\n
"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f
\n
"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f
\n
"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)
\n
"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f
\n
"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)
\n
"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)
\n
"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_x0] ; x * 1/(emu+1f)
\n
"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)
\n
"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_x1] ; x * 1/(emu+1f)
\n
"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_x0] ; x * 1/(emu+1f)
\n
"
:
[
v_y0
]
"=v"
(
y0
),
[
v_y1
]
"=v"
(
y1
),
[
v_tmp0
]
"+v"
(
tmp0
),
[
v_tmp1
]
"+v"
(
tmp1
)
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_x1] ; x * 1/(emu+1f)
\n
"
:
[
v_x0
]
"v"
(
x
.
x
),
[
v_x1
]
"v"
(
x
.
y
),
[
s_c1
]
"s"
(
c1
),
[
v_c2
]
"v"
(
c2
),
[
s_log2e
]
"s"
(
log2e_
)
:
[
v_y0
]
"=v"
(
y0
),
[
v_y1
]
"=v"
(
y1
),
[
v_tmp0
]
"+v"
(
tmp0
),
[
v_tmp1
]
"+v"
(
tmp1
)
:
);
:
[
v_x0
]
"v"
(
x
.
x
),
[
v_x1
]
"v"
(
x
.
y
),
[
s_c1
]
"s"
(
c1
),
[
v_c2
]
"v"
(
c2
),
[
s_log2e
]
"s"
(
log2e_
)
:
);
y
.
x
=
y0
;
y
.
x
=
y0
;
y
.
y
=
y1
;
y
.
y
=
y1
;
}
}
...
...
include/ck_tile/ops/flatmm/pipeline/uk/flatmm_uk_gfx9_32x512x128_1x4x1_16x16x16.hpp
View file @
84755f74
...
@@ -72,7 +72,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
...
@@ -72,7 +72,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
sequence
<
2
,
1
>
,
// !! note here is different
sequence
<
2
,
1
>
,
// !! note here is different
sequence
<
0
,
0
>>
{};
sequence
<
0
,
0
>>
{};
using
WG
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
using
WG
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
...
@@ -82,7 +82,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
...
@@ -82,7 +82,7 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
static
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
static
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
{
{
using
CDataType
=
float
;
using
CDataType
=
float
;
constexpr
auto
c_block_dstr
=
MakeCBlockDist
();
constexpr
auto
c_block_dstr
=
MakeCBlockDist
();
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
auto
c_block_tensor
=
make_static_distributed_tensor
<
CDataType
>
(
c_block_dstr
);
return
c_block_tensor
;
return
c_block_tensor
;
...
@@ -180,8 +180,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
...
@@ -180,8 +180,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsLoadDesc_A
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsLoadDesc_A
()
{
{
// load from LDS to register, every wave has same layout
// load from LDS to register, every wave has same layout
constexpr
index_t
KPack_
=
8
;
// GetSmemKPack_A<Problem>(); // LDS
constexpr
index_t
KPack_
=
8
;
// GetSmemKPack_A<Problem>(); // LDS
constexpr
index_t
KPad
=
KPack_
;
// pad between warps
constexpr
index_t
KPad
=
KPack_
;
// pad between warps
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKLane
=
4
;
...
@@ -189,26 +189,25 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
...
@@ -189,26 +189,25 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
constexpr
index_t
kKIter
=
2
;
constexpr
index_t
kKIter
=
2
;
static_assert
(
KPack_
==
(
kABKPerLane
*
kKIter
));
static_assert
(
KPack_
==
(
kABKPerLane
*
kKIter
));
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lds_block_desc_0
=
make_tuple
(
number
<
Repeat_M
>
{},
// m0 y
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Repeat_M
>
{},
// m0 y
number
<
kAMLane
>
{},
// m1 p
number
<
kAMLane
>
{},
// m1 p
number
<
Repeat_K
>
{},
// k0 y
number
<
Repeat_K
>
{},
// k0 y
number
<
kABKLane
>
{},
// k1 p
number
<
kABKLane
>
{},
// k1 p
number
<
KPack_
>
{}),
// k2 y-vector
number
<
KPack_
>
{}),
// k2 y-vector
make_tuple
(
number
<
kAMLane
*
(
Block_K
+
KPad
)
>
{},
// m0
make_tuple
(
number
<
kAMLane
*
(
Block_K
+
KPad
)
>
{},
// m0
number
<
Block_K
+
KPad
>
{},
// m1
number
<
Block_K
+
KPad
>
{},
// m1
number
<
kABKLane
*
KPack_
>
{},
// k0
number
<
kABKLane
*
KPack_
>
{},
// k0
number
<
KPack_
>
{},
// k1
number
<
KPack_
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
1
>
{}),
// k2
number
<
KPack_
>
{},
// lds load vector
number
<
KPack_
>
{},
// lds load vector
number
<
1
>
{});
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_merge_transform
(
make_merge_transform
(
make_tuple
(
make_tuple
(
number
<
Repeat_K
>
{},
number
<
kABKLane
>
{},
number
<
KPack_
>
{}))),
number
<
Repeat_K
>
{},
number
<
kABKLane
>
{},
number
<
KPack_
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
@@ -291,12 +290,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
...
@@ -291,12 +290,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
},
},
number
<
a_sld
.
get_num_of_access
()
>
{});
number
<
a_sld
.
get_num_of_access
()
>
{});
// printf("----- tid:%d, a_sld:%d\n", static_cast<index_t>(threadIdx.x),
// printf("----- tid:%d, a_sld:%d\n", static_cast<index_t>(threadIdx.x),
// static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset()));
// static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset()));
index_t
loop_cnt
=
k
/
Block_K
;
index_t
loop_cnt
=
k
/
Block_K
;
// this is the acc thread buffer
// this is the acc thread buffer
...
...
include/ck_tile/ops/fused_moe.hpp
View file @
84755f74
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
0 → 100644
View file @
84755f74
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace
ck_tile
{
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
// struct FusedMoeGemmHostArgs
// {
// const void* a_ptr; // [m, k], input token
// const void* a_scale_ptr; // [m, 1], token scale
// const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
// const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
// const void* g_scale_ptr; // [e, 1, n], gate(up) scale
// const void* d_scale_ptr; // [e, 1, k], down scale
// const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
// void* o_ptr; // [m, k], output token
// const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
// const void* sorted_weight_ptr; // [max_num_tokens_padded]
// const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
// const void* num_sorted_tiles_ptr; // [1]
// index_t hidden_size; // k
// index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
// index_t num_tokens; // input number of tokens for current iteration
// index_t num_experts; // number of groups
// index_t topk; // need this?
// index_t stride_token; // for input/output, stride for each row, should >= hidden_size
// };
// This is scatter/gather b2b group-gemm
template
<
typename
Partitioner_
,
typename
Pipeline_
,
typename
Epilogue_
>
struct
FusedMoeGemmGlKernel
{
using
Partitioner
=
remove_cvref_t
<
Partitioner_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
// TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using
BlockShape
=
typename
Pipeline
::
BlockShape
;
// this is FusedMoeGemmShape
static
constexpr
index_t
BlockSize_
=
BlockShape
::
BlockSize
;
using
ADataType
=
typename
Pipeline
::
Problem
::
ADataType
;
using
GDataType
=
typename
Pipeline
::
Problem
::
GDataType
;
using
DDataType
=
typename
Pipeline
::
Problem
::
DDataType
;
using
AccDataType
=
typename
Pipeline
::
Problem
::
AccDataType
;
using
ODataType
=
typename
Pipeline
::
Problem
::
ODataType
;
using
AScaleDataType
=
typename
Pipeline
::
Problem
::
AScaleDataType
;
using
GScaleDataType
=
typename
Pipeline
::
Problem
::
GScaleDataType
;
using
DScaleDataType
=
typename
Pipeline
::
Problem
::
DScaleDataType
;
using
YSmoothScaleDataType
=
typename
Pipeline
::
Problem
::
YSmoothScaleDataType
;
using
TopkWeightDataType
=
typename
Pipeline
::
Problem
::
TopkWeightDataType
;
using
IndexDataType
=
typename
Pipeline
::
Problem
::
IndexDataType
;
using
YDataType
=
typename
Pipeline
::
Problem
::
YDataType
;
using
Traits
=
typename
Pipeline
::
Problem
::
Traits
;
static
constexpr
bool
IsGateOnly
=
Traits
::
IsGateOnly
;
static
constexpr
bool
UseSmoothQuant
=
Traits
::
UseSmoothQuant
;
static
constexpr
bool
PadHiddenSize
=
Traits
::
PadHiddenSize
;
static
constexpr
bool
PadIntermediateSize
=
Traits
::
PadIntermediateSize
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
float
>
{
static
constexpr
const
char
*
name
=
"fp32"
;
};
template
<
>
struct
t2s
<
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using
S_
=
BlockShape
;
auto
prec_str
=
[
&
]
()
{
std
::
string
base_str
=
_SS_
(
t2s
<
ADataType
>::
name
);
if
(
!
std
::
is_same_v
<
ADataType
,
GDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
GDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"fused_moe_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M0
)
+
"x"
+
_TS_
(
S_
::
Block_N0
)
+
"x"
+
_TS_
(
S_
::
Block_K0
)
+
"x"
+
_TS_
(
S_
::
Block_N1
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N0
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_K0
)
+
"_"
+
_TS_
(
S_
::
Warp_M0
)
+
"x"
+
_TS_
(
S_
::
Warp_N0
)
+
"x"
+
_TS_
(
S_
::
Warp_K0
)
+
"_"
+
_SS_
(
Pipeline
::
name
);
#undef _SS_
#undef _TS_
// clang-format on
}
struct
FusedMoeGemmKargs
{
const
void
*
a_ptr
;
// [m, k], input token
const
void
*
a_scale_ptr
;
// [m, 1], token scale
const
void
*
g_ptr
;
// [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const
void
*
d_ptr
;
// [e, n, k], pre-shuffle([e, nr, kr, w])
const
void
*
g_scale_ptr
;
// [e, 1, n], gate(up) scale
const
void
*
d_scale_ptr
;
// [e, 1, k], down scale
const
void
*
y_smooth_scale_ptr
;
// [e, 1, n], smooth-quant-scale for 2nd gemm input
void
*
o_ptr
;
// [m, k], output token
const
void
*
sorted_token_ids_ptr
;
const
void
*
sorted_weight_ptr
;
const
void
*
sorted_expert_ids_ptr
;
const
void
*
num_sorted_tiles_ptr
;
index_t
hidden_size
;
// k
index_t
intermediate_size
;
// n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t
num_tokens
;
// input number of tokens for current iteration
index_t
num_experts
;
// number of groups
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using
Kargs
=
FusedMoeGemmKargs
;
using
Hargs
=
FusedMoeGemmHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
// TODO: hargs/kargs not guranteed to be the same
return
bit_cast
<
Kargs
>
(
hargs
);
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
constexpr
index_t
block_m
=
BlockShape
::
Block_M0
;
int
max_num_tokens_padded
=
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
block_m
-
hargs
.
topk
;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
BlockSize_
);
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Pipeline
::
GetSmemSize
();
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
num_sorted_tiles_ptr
));
constexpr
index_t
hidden_radio_0
=
IsGateOnly
?
1
:
2
;
index_t
nr_0
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Nr0
;
index_t
kr_0
=
kargs
.
hidden_size
/
BlockShape
::
Block_Kr0
;
index_t
nr_1
=
kargs
.
hidden_size
/
BlockShape
::
Block_Nr1
;
// should be same as kr_0
index_t
kr_1
=
kargs
.
intermediate_size
/
BlockShape
::
Block_Kr1
;
// should be same as nr_0
index_t
expert_stride_0
=
kargs
.
intermediate_size
*
hidden_radio_0
*
kargs
.
hidden_size
;
index_t
expert_stride_1
=
kargs
.
intermediate_size
*
kargs
.
hidden_size
;
__shared__
CK_TILE_LDS_ADDR
ADataType
smem
[
GetSmemSize
()];
// note this is in unit of tile, need multiple tile size to get the index(i_m and i_n)
const
auto
[
sorted_tile_id
,
intermediate_tile_id
]
=
Partitioner
{}(
num_sorted_tiles
,
kargs
.
intermediate_size
);
if
(
sorted_tile_id
>=
num_sorted_tiles
)
return
;
const
IndexDataType
expert_id
=
__builtin_amdgcn_readfirstlane
(
reinterpret_cast
<
const
IndexDataType
*>
(
kargs
.
sorted_expert_ids_ptr
)[
sorted_tile_id
]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t
interm_idx_nr
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_Nr0
);
const
auto
a_coord
=
Pipeline
::
GetACoord
();
// 2d thread offset, [i_row, i_col]
const
auto
sorted_token_id
=
a_coord
[
number
<
0
>
{}]
+
sorted_tile_id
*
BlockShape
::
Block_M0
;
index_t
token_id
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
)[
sorted_token_id
];
auto
topk_weight
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
)[
sorted_token_id
];
const
auto
a_window
=
[
&
]()
{
// A is already pre-padded in previous kernel
const
ADataType
*
a_ptr
=
reinterpret_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
auto
a_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
number
<
Pipeline
::
kAlignmentA
>
{},
number
<
1
>
{});
// gather is here use indexing transform
const
auto
a_gather_view_
=
transform_tensor_view
(
a_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
a_window_
=
make_tile_window
(
a_gather_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
return
a_window_
;
}();
// TODO: gtile using NSub to have less register pressure
const
auto
g_window
=
[
&
]()
{
const
GDataType
*
g_ptr
=
reinterpret_cast
<
const
GDataType
*>
(
kargs
.
g_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_0
+
interm_idx_nr
*
kr_0
*
BlockShape
::
Block_W0
;
const
auto
g_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
g_ptr
,
make_tuple
(
nr_0
,
kr_0
,
number
<
BlockShape
::
Block_W0
>
{}),
make_tuple
(
kr_0
*
BlockShape
::
Block_W0
,
number
<
BlockShape
::
Block_W0
>
{},
1
),
number
<
Pipeline
::
kAlignmentG
>
{},
number
<
1
>
{});
const
auto
g_view_1_
=
pad_tensor_view
(
g_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
sequence
<
PadIntermediateSize
,
PadHiddenSize
,
0
>
{});
const
auto
g_window_
=
make_tile_window
(
g_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr0
>
{},
number
<
BlockShape
::
Block_Kr0
>
{},
number
<
BlockShape
::
Block_W0
>
{}),
{
0
,
0
,
0
});
return
g_window_
;
}();
const
auto
d_window
=
[
&
]()
{
const
DDataType
*
d_ptr
=
reinterpret_cast
<
const
DDataType
*>
(
kargs
.
d_ptr
)
+
static_cast
<
long_index_t
>
(
expert_id
)
*
expert_stride_1
+
interm_idx_nr
*
BlockShape
::
Block_W1
;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const
auto
d_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
d_ptr
,
make_tuple
(
nr_1
,
kr_1
,
BlockShape
::
Block_W1
),
make_tuple
(
kr_1
*
BlockShape
::
Block_W1
,
BlockShape
::
Block_W1
,
1
),
number
<
Pipeline
::
kAlignmentD
>
{},
number
<
1
>
{});
const
auto
d_view_1_
=
pad_tensor_view
(
d_view_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
sequence
<
PadHiddenSize
,
PadIntermediateSize
,
0
>
{});
const
auto
d_window_
=
make_tile_window
(
d_view_1_
,
make_tuple
(
number
<
BlockShape
::
Block_Nr1
>
{},
number
<
BlockShape
::
Block_Kr1
>
{},
number
<
BlockShape
::
Block_W1
>
{}),
{
0
,
0
,
0
});
return
d_window_
;
}();
auto
o_window
=
[
&
]()
{
ODataType
*
o_ptr
=
reinterpret_cast
<
ODataType
*>
(
kargs
.
o_ptr
);
auto
o_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
o_ptr
,
make_tuple
(
kargs
.
num_tokens
,
kargs
.
hidden_size
),
make_tuple
(
kargs
.
stride_token
,
1
),
number
<
Pipeline
::
kAlignmentO
>
{},
number
<
1
>
{});
// gather is here
auto
o_scatter_view_
=
transform_tensor_view
(
o_view_
,
make_tuple
(
make_indexing_transform
(
kargs
.
num_tokens
,
token_id
),
make_pass_through_transform
(
kargs
.
hidden_size
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
auto
o_window_
=
make_tile_window
(
o_scatter_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
{
0
,
0
});
return
o_window_
;
}();
// do compute yeah
Pipeline
{}(
a_window
,
g_window
,
d_window
,
o_window
,
topk_weight
,
smem
,
kargs
.
hidden_size
,
kargs
.
intermediate_size
,
kargs
.
stride_token
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
View file @
84755f74
...
@@ -70,12 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmGl
...
@@ -70,12 +70,16 @@ struct FusedMoeGemmPipeline_FlatmmGl
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
{
// matrix a or tokens smem
constexpr
index_t
smem_mat_a
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_K0
*
sizeof
(
ADataType
);
// shuffle C matrix
constexpr
index_t
smem_bridge
=
constexpr
index_t
smem_bridge
=
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
smem_bridge
;
return
max
(
smem_mat_a
,
smem_bridge
);
}
}
template
<
typename
Karg
>
template
<
typename
Karg
>
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_DEVICE
auto
operator
()(
const
Karg
&
kargs
,
CK_TILE_LDS_ADDR
void
*
smem
,
CK_TILE_LDS_ADDR
void
*
smem
,
...
@@ -86,7 +90,6 @@ struct FusedMoeGemmPipeline_FlatmmGl
...
@@ -86,7 +90,6 @@ struct FusedMoeGemmPipeline_FlatmmGl
ignore
=
smem
;
ignore
=
smem
;
ignore
=
sorted_tile_id
;
ignore
=
sorted_tile_id
;
ignore
=
intermediate_tile_id
;
ignore
=
intermediate_tile_id
;
}
}
};
};
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
View file @
84755f74
...
@@ -590,39 +590,40 @@ struct FusedMoeGemmPipelineFlatmmPolicy
...
@@ -590,39 +590,40 @@ struct FusedMoeGemmPipelineFlatmmPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsStoreForUKDesc
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsStoreForUKDesc
()
{
{
constexpr
index_t
WarpPerBlock_N
=
Problem
::
BlockShape
::
WarpPerBlock_N0
;
constexpr
index_t
WarpPerBlock_N
=
Problem
::
BlockShape
::
WarpPerBlock_N0
;
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N0
;
constexpr
index_t
Repeat_N
=
Problem
::
BlockShape
::
Repeat_N0
;
constexpr
index_t
Repeat_M
=
Problem
::
BlockShape
::
Repeat_M0
;
constexpr
index_t
Repeat_M
=
Problem
::
BlockShape
::
Repeat_M0
;
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kAMLane
=
16
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKLane
=
4
;
constexpr
index_t
kABKPerLane
=
4
;
constexpr
index_t
kABKPerLane
=
4
;
constexpr
index_t
KPack
=
kABKPerLane
;
constexpr
index_t
KPack
=
kABKPerLane
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Repeat_M
>
{},
// m
make_tuple
(
number
<
Repeat_M
>
{},
// m
number
<
Repeat_N
>
{},
// n
number
<
Repeat_N
>
{},
// n
number
<
WarpPerBlock_N
>
{},
// n
number
<
WarpPerBlock_N
>
{},
// n
number
<
kABKLane
>
{},
// n
number
<
kABKLane
>
{},
// n
number
<
kAMLane
>
{},
// m
number
<
kAMLane
>
{},
// m
number
<
KPack
>
{}),
// n
number
<
KPack
>
{}),
// n
make_tuple
(
number
<
Repeat_N
*
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// m
make_tuple
(
number
<
Repeat_N
*
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// m
number
<
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
WarpPerBlock_N
*
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
kABKLane
*
kAMLane
*
KPack
>
{},
// n
number
<
kAMLane
*
KPack
>
{},
// n
number
<
kAMLane
*
KPack
>
{},
// n
number
<
KPack
>
{},
// m
number
<
KPack
>
{},
// m
number
<
1
>
{}),
// n
number
<
1
>
{}),
// n
number
<
KPack
>
{},
// lds store vector(actually no explicit store)
number
<
KPack
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
number
<
1
>
{});
constexpr
auto
desc
=
transform_tensor_descriptor
(
constexpr
auto
desc
=
transform_tensor_descriptor
(
lds_block_desc_0
,
lds_block_desc_0
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_merge_transform
(
make_tuple
(
number
<
Repeat_M
>
{},
number
<
kAMLane
>
{})),
make_merge_transform
(
make_tuple
(
number
<
Repeat_N
>
{},
make_merge_transform
(
make_tuple
(
number
<
Repeat_N
>
{},
number
<
WarpPerBlock_N
>
{},
number
<
kABKLane
>
{},
number
<
KPack
>
{}))
number
<
WarpPerBlock_N
>
{},
),
number
<
kABKLane
>
{},
make_tuple
(
sequence
<
0
,
4
>
{},
sequence
<
1
,
2
,
3
,
5
>
{}),
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
,
4
>
{},
sequence
<
1
,
2
,
3
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
desc
;
return
desc
;
}
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
View file @
84755f74
...
@@ -342,13 +342,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -342,13 +342,11 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto
bridge_sst_win
=
[
&
]()
{
auto
bridge_sst_win
=
[
&
]()
{
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
desc_
=
Policy
::
template
MakeBridgeLdsStoreForUKDesc
<
Problem
>();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
constexpr
auto
dist_
=
Policy
::
template
GetUK_0
<
Problem
>().
MakeCBlockDist
();
return
make_tile_window_linear
(
return
make_tile_window_linear
(
make_tensor_view
<
address_space_enum
::
lds
>
(
make_tensor_view
<
address_space_enum
::
lds
>
(
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
),
reinterpret_cast
<
YDataType
*>
(
smem
),
desc_
.
get_lengths
(),
desc_
),
{
0
,
0
},
desc_
.
get_lengths
(),
dist_
);
{
0
,
0
},
dist_
);
}();
}();
auto
o_res
=
auto
o_res
=
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
make_wave_buffer_resource
(
reinterpret_cast
<
const
ODataType
*>
(
kargs
.
o_ptr
),
...
@@ -442,16 +440,17 @@ struct FusedMoeGemmPipeline_FlatmmUk
...
@@ -442,16 +440,17 @@ struct FusedMoeGemmPipeline_FlatmmUk
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
BlockShape
::
Block_W0
);
// tile offset for B matrix each unroll
// return ;
// return ;
//sweep_tile(acc_0,
//
sweep_tile(acc_0,
// [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
// [&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
sweep_tile
(
acc_0
,
sweep_tile
(
[
&
](
auto
idx0
,
auto
idx1
)
{
acc_0
,
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
[
&
](
auto
idx0
,
auto
idx1
)
{
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
fp32x2_t
v_
{
acc_0
(
idx0
),
acc_0
(
idx1
)};
acc_0
(
idx0
)
=
v_
.
x
;
typename
Problem
::
GateActivation
{}(
v_
,
v_
);
acc_0
(
idx1
)
=
v_
.
y
;
acc_0
(
idx0
)
=
v_
.
x
;
},
acc_0
(
idx1
)
=
v_
.
y
;
sequence
<
1
,
2
>
{});
},
sequence
<
1
,
2
>
{});
#if 0
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment