Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
40182e94
Commit
40182e94
authored
Jan 14, 2025
by
letaoqin
Browse files
change w read
parent
cf01f064
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
15 deletions
+21
-15
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+6
-4
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+1
-0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+14
-11
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
40182e94
...
...
@@ -129,7 +129,7 @@ auto create_args(int argc, char* argv[])
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"t"
,
"128"
,
"num input tokens"
)
.
insert
(
"e"
,
"32"
,
"num of experts"
)
.
insert
(
"k"
,
"
2
"
,
"topk"
)
.
insert
(
"k"
,
"
5
"
,
"topk"
)
.
insert
(
"h"
,
"8192"
,
"hidden_size of this model"
)
.
insert
(
"i"
,
"8192"
,
"intermediate_size between 2 gemms of FFN"
)
.
insert
(
"stride"
,
"-1"
,
"stride per row, if -1 then equal to hidden_size"
)
...
...
@@ -285,6 +285,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
// ck_tile::FillConstant<TopkWeightDataType>{0.1}(topk_weight_host);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
,
seed
,
true
}(
topk_weight_host
);
}
...
...
@@ -308,6 +309,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
FillUniformDistribution
<
GScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sg_host
);
ck_tile
::
FillUniformDistribution
<
DScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sd_host
);
ck_tile
::
FillUniformDistribution
<
YSmoothScaleDataType
>
{
-
.5
f
,
.5
f
,
seed
,
true
}(
sy_host
);
// ck_tile::FillConstant<TopkWeightDataType>{0.5}(topk_weight_host);
ck_tile
::
FillUniformDistribution
<
TopkWeightDataType
>
{
0.0
f
,
1.0
f
,
seed
,
true
}(
topk_weight_host
);
}
...
...
@@ -397,11 +399,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
// output_matrix_2d(a_host, tokens, hidden_size);
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
std
::
cout
<<
num_sorted_tiles_host
<<
std
::
endl
;
//
std::cout << num_sorted_tiles_host << std::endl;
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
// output_matrix_3d(d_host, experts, hidden_size, shared_intermediate_size_1);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
//
std::cout << topk_weight_host << std::endl;
//
std::cout << sorted_expert_ids_host << std::endl;
std
::
cout
<<
topk_weight_host
<<
std
::
endl
;
std
::
cout
<<
sorted_weight_host
<<
std
::
endl
;
// done, preparing GPU buffer
...
...
include/ck_tile/host/reference/reference_fused_moe.hpp
View file @
40182e94
...
...
@@ -171,6 +171,7 @@ void reference_fused_moe(
// printf("in:%d, %f\t", i_n, acc);
acc_1
(
0
,
i_n
)
=
acc
*
weight
;
// multiple weight here
}
(
void
)
weight
;
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
{
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
40182e94
...
...
@@ -170,6 +170,16 @@ struct FusedMoeGemmPipeline_General
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
constexpr
auto
w_dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
s_acc
.
get_tile_distribution
().
get_static_tile_distribution_encoding
(),
sequence
<
1
>
{}));
auto
w_global_to_dram_window
=
make_tile_window
(
w_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
w_window_
.
get_window_origin
(),
w_dstr
);
auto
w
=
load_tile
(
w_global_to_dram_window
);
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
// block_sync_load_raw();
...
...
@@ -250,16 +260,6 @@ struct FusedMoeGemmPipeline_General
using
OaccBlockTileType
=
decltype
(
gemm_1
.
MakeCBlockTile
());
auto
o_acc
=
OaccBlockTileType
{};
constexpr
auto
w_dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
s_acc
.
get_tile_distribution
().
get_static_tile_distribution_encoding
(),
sequence
<
1
>
{}));
auto
w_global_to_dram_window
=
make_tile_window
(
w_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
w_window_
.
get_window_origin
(),
w_dstr
);
auto
w
=
load_tile
(
w_global_to_dram_window
);
float
weight
=
type_convert
<
float
>
(
w
.
get_thread_buffer
()[
0
]);
#if 0
constexpr index_t w_buffer_size = decltype(w)::get_thread_buffer_size();
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
...
...
@@ -294,7 +294,7 @@ struct FusedMoeGemmPipeline_General
// d data
auto
d_global_to_dram_window
=
make_tile_window
(
d_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_N
0
>
{},
number
<
BlockShape
::
Block_K
0
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_N
1
>
{},
number
<
BlockShape
::
Block_K
1
>
{}),
d_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_D
<
Problem
>());
auto
d
=
load_tile
(
d_global_to_dram_window
);
...
...
@@ -339,6 +339,8 @@ struct FusedMoeGemmPipeline_General
}
}
};
float
weight
=
type_convert
<
float
>
(
w
.
get_thread_buffer
()[
0
]);
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
index_t
iCounter1
=
n1_loops
-
1
;
...
...
@@ -382,6 +384,7 @@ struct FusedMoeGemmPipeline_General
#endif
}
// store_tile(o_window_, a_dram_block);
ignore
=
weight
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment