Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
593dd7ad
Commit
593dd7ad
authored
Dec 04, 2024
by
letaoqin
Browse files
clear some code
parent
6cb91035
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
17 deletions
+19
-17
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
+6
-0
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+4
-1
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+8
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+1
-16
No files found.
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
593dd7ad
...
...
@@ -22,6 +22,12 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
using
t_
=
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
// t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
// {
// using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
// r = fused_moegemm_<t_>(s, a);
// }
// clang-format on
return
r
;
}
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
593dd7ad
...
...
@@ -8,7 +8,10 @@
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
ck_tile
::
f
p
16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
fmoe_
<
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
ck_tile
::
b
f16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
32
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
example/ck_tile/16_fused_moe_general/main.cpp
View file @
593dd7ad
...
...
@@ -429,6 +429,14 @@ int main(int argc, char* argv[])
prec_kw
=
(
prec_kw
==
"auto"
)
?
"fp32"
:
prec_kw
;
// no dynamic quant case
// if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
// {
// return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
// arg_parser)
// ? 0
// : -2;
// }
// else
if
(
prec_i
==
"bf16"
&&
prec_w
==
"bf16"
&&
prec_o
==
"bf16"
&&
prec_kw
==
"fp32"
)
{
return
run
<
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
ck_tile
::
fp16_t
,
float
,
float
,
float
,
float
>
(
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
593dd7ad
...
...
@@ -193,10 +193,6 @@ struct FusedMoeGemmPipeline_General
}
// relu
const
auto
activation
=
ck_tile
::
element_wise
::
Gelu
{};
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
// });
tile_elementwise_inout
(
activation
,
s_acc
,
s_acc
);
#if 0
PrintMem(s_acc);
...
...
@@ -210,18 +206,7 @@ struct FusedMoeGemmPipeline_General
{
0
,
0
});
// cast data to YDataType
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
// //y_pre.get_thread_buffer()(i) = type_convert<YDataType>(s_acc.get_thread_buffer()[i]);
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// printf("soure value: %f to value: %f\n",
// s_acc.get_thread_buffer()[i],
// type_convert<float>(y_pre.get_thread_buffer()[i]));
// }
// });
#if 1
#if 0
PrintMem(y_pre);
#endif
// save to lds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment