gaoqiong / composable_kernel_ROCM · Commits

Commit d0405504
Authored Nov 11, 2024 by carlushuang

    update

Parent: 9d3cdd21
Showing 5 changed files with 73 additions and 43 deletions (+73 -43)
example/ck_tile/15_fused_moe/main.cpp                            +48  -39
include/ck_tile/host/host_tensor.hpp                             +16   -0
include/ck_tile/host/reference/reference_fused_moe.hpp            +2   -1
include/ck_tile/ops/fused_moe.hpp                                 +4   -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp     +3   -3
example/ck_tile/15_fused_moe/main.cpp
...
...
@@ -105,9 +105,9 @@ auto create_args(int argc, char* argv[])
         .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
         .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
-        .insert("gate_only", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
+        .insert("gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
-        .insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
+        .insert("balance", "0", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
         .insert("warmup", "5", "cold iter")
         .insert("repeat", "20", "hot iter");
...
...
@@ -121,6 +121,8 @@ auto create_args(int argc, char* argv[])
 template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
+    std::cout << "xxxx " << __LINE__ << std::flush << std::endl;
     ck_tile::index_t tokens  = arg_parser.get_int("t");
     ck_tile::index_t experts = arg_parser.get_int("e");
     ck_tile::index_t topk    = arg_parser.get_int("k");
...
...
@@ -150,7 +152,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
     int balance = arg_parser.get_int("balance");
     int tp      = arg_parser.get_int("tp");
-    ck_tile::index_t shared_intermediate_size = intermediate_size * (gate_only ? 1 : 2) / tp;
+    // w0 (Gate+Up or Gate only, N size)
+    ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
+    // w1 (Down, N size)
+    ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
+
+    auto prec_str = [&]() {
+        auto base_str = prec_i;
+        if(prec_i != prec_w)
+            base_str += "x" + prec_w;
+        if(prec_i != prec_o)
+            base_str += "=" + prec_o;
+        if(fused_quant != 0)
+        {
+            base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
+        }
+        return base_str;
+    }();
+
+    std::cout << "[" << prec_str << "]"
+              << " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
+              << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
+              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;
+
     using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
     using ADataType  = typename TypeConfig::ADataType;
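For intuition about the new split (illustrative numbers, not from the commit): with intermediate_size = 4096, tp = 2 and gate_only = 0, w0 holds the concatenated gate+up projections, so shared_intermediate_size_0 = 4096 * 2 / 2 = 4096, while the down projection w1 keeps shared_intermediate_size_1 = 4096 / 2 = 2048. With gate_only = 1 both sizes come out to 2048.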
...
...
@@ -167,36 +190,36 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host verify
     ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
-    ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size, hidden_size});
-    ck_tile::HostTensor<DDataType> d_host({experts, intermediate_size, hidden_size});
+    ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
+    ck_tile::HostTensor<DDataType> d_host({experts, shared_intermediate_size_1, hidden_size});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});

     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
-    ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size});
-    ck_tile::HostTensor<DScaleDataType> sd_host({intermediate_size});
-    ck_tile::HostTensor<YSmoothScaleDataType> sy_host({intermediate_size}); // smooth-quant
+    ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
+    ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
+    ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant

     ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk});         // to be sort
     ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort

-    int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
+    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
     ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
     ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
     ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
         {(max_num_tokens_padded + block_m - 1) / block_m});
     ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});

-    // permute weight
-    ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w);
-    ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w);

     ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
-    ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_perm_host);
-    ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_perm_host);
+    ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
+    ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
     ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f}(sa_host);
     ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f}(sg_host);
     ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
     ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
     ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);

+    // permute weight
+    ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
+    ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);

     // do moe sorting
     if(balance)
     {
...
...
@@ -223,6 +246,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
                             num_sorted_tiles_host.mData[0],
                             experts,
                             block_m);
+    // std::cout << sorted_token_ids_host << std::endl;
+    // std::cout << num_sorted_tiles_host << std::endl;

     // done, preparing GPU buffer
     ck_tile::DeviceMem a_buf(a_host);
     ck_tile::DeviceMem g_perm_buf(g_perm_host);
...
...
@@ -238,24 +265,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
     ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);

-    auto prec_str = [&]() {
-        auto base_str = prec_i;
-        if(prec_i != prec_w)
-            base_str += "x" + prec_w;
-        if(prec_i != prec_o)
-            base_str += "=" + prec_o;
-        if(fused_quant != 0)
-        {
-            base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
-        }
-        return base_str;
-    }();
-
-    std::cout << "[" << prec_str << "]"
-              << " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
-              << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
-              << ", go:" << gate_only << ", q:" << fused_quant << std::flush;

     fused_moegemm_traits traits{prec_i,
                                 prec_w,
                                 prec_o,
...
...
@@ -280,7 +289,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                                  sorted_expert_ids_buf.GetDeviceBuffer(),
                                  num_sorted_tiles_buf.GetDeviceBuffer(),
                                  hidden_size,
-                                 intermediate_size,
+                                 shared_intermediate_size_0,
                                  tokens,
                                  experts,
                                  topk,
...
...
@@ -302,8 +311,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
     float gb_per_sec = num_byte / 1.E6 / ave_time;
     std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
 #else
-    std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size * hidden_size;
-    std::size_t flop_gemm_1 = 2 * tokens * topk * hidden_size * hidden_size;
+    std::size_t flop_gemm_0 = 2 * tokens * topk * shared_intermediate_size_0 * hidden_size;
+    std::size_t flop_gemm_1 = 2 * tokens * topk * shared_intermediate_size_1 * hidden_size;
     double tflops = (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ave_time) * 1e-3) / 1e12;

     // float gb_per_sec = num_byte / 1.E6 / ave_time;
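Sanity check of the updated FLOP count with illustrative numbers (not from the commit): tokens = 128, topk = 2, hidden_size = 4096, shared_intermediate_size_0 = 4096 and shared_intermediate_size_1 = 2048 give flop_gemm_0 = 2 * 128 * 2 * 4096 * 4096 ≈ 8.59e9 and flop_gemm_1 = 2 * 128 * 2 * 2048 * 4096 ≈ 4.29e9. With ave_time = 0.5 (milliseconds, hence the 1e-3 factor) that is (8.59e9 + 4.29e9) / (0.5 * 1e-3) / 1e12 ≈ 25.8 TFLOPS.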
...
...
@@ -331,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
                              tokens,
                              experts,
                              hidden_size,
-                             intermediate_size,
+                             shared_intermediate_size_0,
                              topk,
                              gate_only);
...
...
include/ck_tile/host/host_tensor.hpp
...
...
@@ -590,6 +590,22 @@ struct HostTensor
                                size() * FromSize / ToSize};
     }

+    friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
+    {
+        os << t.mDesc;
+        os << "[";
+        for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
+        {
+            if(0 < idx)
+            {
+                os << ", ";
+            }
+            os << t.mData[idx];
+        }
+        os << "]";
+        return os;
+    }
+
     Descriptor mDesc;
     Data mData;
 };
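A minimal usage sketch of the added stream operator (standalone; the 2x2 tensor, its values, and the direct assignment to the public mData member are illustrative assumptions, and the leading part of the output depends on how mDesc formats itself):

    #include <iostream>
    #include "ck_tile/host/host_tensor.hpp"

    int main()
    {
        ck_tile::HostTensor<int> t({2, 2}); // lengths {2, 2}, same constructor style as the example code
        t.mData = {1, 2, 3, 4};             // fill the flat storage directly (mData is public)
        std::cout << t << std::endl;        // prints the descriptor, then "[1, 2, 3, 4]"
        return 0;
    }

This is what the commented-out debug lines in main.cpp (std::cout << sorted_token_ids_host ...) rely on.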
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
...
...
@@ -21,6 +21,7 @@ namespace ck_tile {
 // weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
 //
 // max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
+// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
 // * this could be larger than actual, since actual tokens are on GPU
 //
 // sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
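One way to read the updated bound (inferred from the formula, not stated in the commit): the sorting result pads each expert's segment up to a multiple of M_a and reserves at least one tile per expert even when that expert receives no tokens. Since every token contributes topk distinct expert ids, at least topk experts hold real tokens, so the worst case is topk * input_tokens real entries, plus at most (M_a - 1) padding for each of those topk experts, plus M_a padding for each of the remaining num_experts - topk empty experts, i.e. topk * input_tokens + num_experts * M_a - topk. For example, topk = 2, input_tokens = 16, num_experts = 8, M_a = 32 gives 2*16 + 8*32 - 2 = 286, slightly above the old bound of 2*16 + 8*31 = 280.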
...
...
@@ -94,7 +95,7 @@ void reference_fused_moe(
     ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

-    int max_num_tokens_padded = topk * tokens + experts * (block_m - 1);
+    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
     // assert();
     auto f = [&](auto i_flatten) {
         ck_tile::index_t i_tile = i_flatten / block_m;
...
...
include/ck_tile/ops/fused_moe.hpp
...
...
@@ -6,10 +6,14 @@
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
...
...
@@ -99,7 +99,7 @@ struct FusedMoeGemmHostArgs
     const void* num_sorted_tiles_ptr; // [1]

     index_t hidden_size;       // k
-    index_t intermediate_size; // n (TP slice this)
+    index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
     index_t num_tokens;        // input number of tokens for current iteration
     index_t num_experts;       // number of groups
     index_t topk;              // need this?
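A minimal host-side sketch of how this field is now expected to be filled, mirroring the example's main.cpp (full_intermediate_size is a hypothetical name for the unsliced model dimension; unrelated members are omitted):

    // tp: tensor-parallel degree; gate_only: 1 = gate projection only, 0 = gate+up concatenated
    ck_tile::index_t shared_intermediate_size_0 =
        full_intermediate_size * (gate_only ? 1 : 2) / tp; // N size of w0 on this rank

    FusedMoeGemmHostArgs hargs{};
    hargs.hidden_size       = hidden_size;                // k
    hargs.intermediate_size = shared_intermediate_size_0; // n / TP; per the comment above, the Down (w1)
                                                          // size is derived by halving when gate+up is fused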
...
...
@@ -176,7 +176,7 @@ struct FusedMoeGemmKernel
     const void* num_sorted_tiles_ptr;

     index_t hidden_size;       // k
-    index_t intermediate_size; // n (TP slice this)
+    index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
     index_t num_tokens;        // input number of tokens for current iteration
     index_t num_experts;       // number of groups
     index_t topk;              // need this?
...
...
@@ -198,7 +198,7 @@ struct FusedMoeGemmKernel
     {
         constexpr index_t block_m = BlockShape::Block_M0;
-        int max_num_tokens_padded = hargs.topk * hargs.num_tokens + hargs.num_experts * (block_m - 1);
+        int max_num_tokens_padded = hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
         return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
     }
...
...