Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2baf9422
"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "ddfb38efa733f52ced8d02b03c9fd913e5d7e044"
Commit
2baf9422
authored
Dec 05, 2024
by
letaoqin
Browse files
add moe general
parent
c918cd4f
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
92 additions
and
18 deletions
+92
-18
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/17_fused_moe_general/instances/fused_moegemm_api.cpp
+2
-2
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp
...used_moe_general/instances/fused_moegemm_api_internal.hpp
+4
-4
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
..._fused_moe_general/instances/fused_moegemm_api_traits.hpp
+2
-2
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+1
-1
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp
...17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp
+1
-1
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+6
-5
example/ck_tile/17_fused_moe_general/misc/moe-0.png
example/ck_tile/17_fused_moe_general/misc/moe-0.png
+0
-0
example/ck_tile/17_fused_moe_general/misc/moe-1.png
example/ck_tile/17_fused_moe_general/misc/moe-1.png
+0
-0
example/ck_tile/17_fused_moe_general/misc/moe-2.png
example/ck_tile/17_fused_moe_general/misc/moe-2.png
+0
-0
example/ck_tile/17_fused_moe_general/misc/moe-3.png
example/ck_tile/17_fused_moe_general/misc/moe-3.png
+0
-0
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+72
-0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+3
-3
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...ude/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
+1
-0
No files found.
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
2baf9422
...
...
@@ -19,13 +19,13 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
32
,
32
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
else
if
(
t
.
prec_i
==
"fp16"
&&
t
.
prec_w
==
"fp16"
&&
t
.
prec_o
==
"fp16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
32
,
32
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// clang-format on
...
...
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_internal.hpp
View file @
2baf9422
...
...
@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename
Ts_
::
WarpPerBlock_0
,
typename
Ts_
::
WarpTile_0
,
typename
Ts_
::
BlockTile_1
,
typename
Ts_
::
WarpPerBlock_
0
,
typename
Ts_
::
WarpTile_
0
>
;
typename
Ts_
::
WarpPerBlock_
1
,
typename
Ts_
::
WarpTile_
1
>
;
using
f_problem
=
ck_tile
::
FusedMoeGemmPipelineProblem
<
typename
Ts_
::
ADataType
,
typename
Ts_
::
GDataType
,
...
...
@@ -38,9 +38,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits
>
;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_
FlatmmUk
<
f_problem
>
;
using
f_pipeline
=
ck_tile
::
FusedMoeGemmPipeline_
General
<
f_problem
>
;
using
f_partitioner
=
ck_tile
::
FusedMoeGemmTilePartitioner_Linear
<
f_shape
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemmKernel
<
f_partitioner
,
f_pipeline
,
void
>
;
using
f_kernel
=
ck_tile
::
FusedMoeGemm
Gl
Kernel
<
f_partitioner
,
f_pipeline
,
void
>
;
const
dim3
grids
=
f_kernel
::
GridSize
(
a
);
constexpr
dim3
blocks
=
f_kernel
::
BlockSize
();
...
...
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_api_traits.hpp
View file @
2baf9422
...
...
@@ -44,8 +44,8 @@ struct fmoe_ // traits, ugly name, only used for internal
using
WarpPerBlock_0
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
WarpTile_0
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
/
(
GateOnly_
?
1
:
2
)
>
;
using
WarpPerBlock_1
=
ck_tile
::
remove_cvref_t
<
WarpPerBlock_
>
;
using
BlockTile_1
=
ck_tile
::
sequence
<
BT_
,
BD_
,
BI_
>
;
using
WarpPerBlock_1
=
ck_tile
::
sequence
<
1
,
1
,
4
>
;
//
ck_tile::remove_cvref_t<WarpPerBlock_>;
using
WarpTile_1
=
ck_tile
::
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
ck_tile
::
index_t
GateOnly
=
GateOnly_
;
...
...
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
2baf9422
...
...
@@ -8,7 +8,7 @@
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
example/ck_tile/17_fused_moe_general/instances/fused_moegemm_fp16_m32.cpp
View file @
2baf9422
...
...
@@ -8,7 +8,7 @@
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
5
12
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
16
,
16
,
32
>
,
1
,
0
>
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
12
8
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
example/ck_tile/17_fused_moe_general/main.cpp
View file @
2baf9422
...
...
@@ -261,8 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// permute weight
ck_tile
::
HostTensor
<
GDataType
>
g_perm_host
=
shuffle_moe_weight
(
g_host
,
prec_w
,
1
);
ck_tile
::
HostTensor
<
DDataType
>
d_perm_host
=
shuffle_moe_weight
(
d_host
,
prec_w
,
1
);
//
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
//
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if
(
balance
)
...
...
@@ -345,8 +345,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// done, preparing GPU buffer
ck_tile
::
DeviceMem
a_buf
(
a_host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_
perm_
host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_
perm_
host
);
ck_tile
::
DeviceMem
g_perm_buf
(
g_host
);
ck_tile
::
DeviceMem
d_perm_buf
(
d_host
);
ck_tile
::
DeviceMem
sa_buf
(
sa_host
);
ck_tile
::
DeviceMem
sg_buf
(
sg_host
);
ck_tile
::
DeviceMem
sd_buf
(
sd_host
);
...
...
@@ -390,7 +390,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens
,
experts
,
topk
,
stride
};
stride
,
max_num_tokens_padded
};
float
ave_time
=
fused_moegemm
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
true
,
kname
?
1
:
0
,
warmup
,
repeat
});
...
...
example/ck_tile/17_fused_moe_general/misc/moe-0.png
deleted
100644 → 0
View file @
c918cd4f
75 KB
example/ck_tile/17_fused_moe_general/misc/moe-1.png
deleted
100644 → 0
View file @
c918cd4f
90.4 KB
example/ck_tile/17_fused_moe_general/misc/moe-2.png
deleted
100644 → 0
View file @
c918cd4f
124 KB
example/ck_tile/17_fused_moe_general/misc/moe-3.png
deleted
100644 → 0
View file @
c918cd4f
18.2 KB
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
2baf9422
...
...
@@ -57,4 +57,76 @@ struct indexing_adaptor_onshot_cached
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
};
#define Using_Gather 1
template
<
typename
IndexingType
>
struct
indexing_adaptor
{
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
(
const
IndexingType
*
idx
)
:
cached_idx_
(
idx
)
{}
const
IndexingType
*
cached_idx_
;
#if Using_Gather
mutable
index_t
pre_up_index_
=
0
;
mutable
index_t
pre_low_index_
=
0
;
#endif
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
number
<
0
>
{})
=
*
(
cached_idx_
+
idx_up
[
number
<
0
>
{}]);
#if Using_Gather
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_low_index_
=
idx_low
(
number
<
0
>
{});
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
}
#endif
#endif
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
/*idx_low*/
,
const
UpIdx
&
/*idx_up*/
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
#if !Using_Gather
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
#else
int
up_index
=
idx_diff_up
[
number
<
0
>
{}]
+
pre_up_index_
;
int
low_index
=
*
(
cached_idx_
+
up_index
);
idx_diff_low
(
number
<
0
>
{})
=
low_index
-
pre_low_index_
;
pre_up_index_
=
up_index
;
pre_low_index_
=
low_index
;
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n index form %d to %d, diff from %d to %d \n",
up_index,
low_index,
idx_diff_up[number<0>{}],
idx_diff_low(number<0>{}));
}
#endif
#endif
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
2baf9422
...
...
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
constexpr
index_t
block_m
=
BlockShape
::
Block_M0
;
int
max_num_tokens_padded
=
hargs
.
topk
*
hargs
.
num_tokens
+
hargs
.
num_experts
*
block_m
-
hargs
.
topk
;
//
constexpr index_t block_m = BlockShape::Block_M0;
int
max_num_tokens_padded
=
hargs
.
max_num_tokens_padded
;
//
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return
Partitioner
::
GridSize
(
max_num_tokens_padded
,
hargs
.
intermediate_size
);
}
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
View file @
2baf9422
...
...
@@ -117,6 +117,7 @@ struct FusedMoeGemmHostArgs
index_t
topk
;
// need this?
index_t
stride_token
;
// for input/output, stride for each row, should >= hidden_size
index_t
max_num_tokens_padded
;
// size of sorted_token_ids_ptr
};
// This is scatter/gather b2b group-gemm
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment