Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
69114f25
Commit
69114f25
authored
Nov 29, 2024
by
letaoqin
Browse files
output sacc
parent
bb7c4112
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
23 deletions
+51
-23
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+30
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+20
-20
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
69114f25
...
...
@@ -280,7 +280,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
block_m
);
// output_matrix_2d(a_host, tokens, hidden_size);
//
std::cout << sorted_token_ids_host << std::endl;
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
// std::cout << num_sorted_tiles_host << std::endl;
output_matrix_3d
(
g_host
,
experts
,
shared_intermediate_size_0
,
hidden_size
);
std
::
cout
<<
sorted_expert_ids_host
<<
std
::
endl
;
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
69114f25
...
...
@@ -156,14 +156,14 @@ struct FusedMoeGemmPipeline_General
// load g to register
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
#if
1
#if
0
{
constexpr auto a_spans = decltype(g_dram_block)::get_distributed_spans();
int counter = 0;
sweep_tile_span(a_spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
if
(
threadIdx
.
x
==
1
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
if(threadIdx.x ==
0
&& blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
counter = counter + 1;
index_t idn_0 = idxn.impl_.at(0);
...
...
@@ -208,6 +208,34 @@ struct FusedMoeGemmPipeline_General
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
#if 1
// Debug dump of s_acc right after gemm_0: a single thread (threadIdx.x == 0 in
// block (0, 0, 0)) prints every element visited by the two nested span sweeps.
{
    constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
    int counter            = 0;
    // a_spans[0] = 1;
    sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
        sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
            // Address the element with (idxm, idxn). The previous code passed idxn
            // for both dimensions, so the outer-span index never reached s_acc and
            // the printed value did not correspond to the printed idm_0.
            constexpr auto i_j_idx = make_tuple(idxm, idxn);
            if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
            {
                counter       = counter + 1;
                index_t idm_0 = idxm.impl_.at(0);
                index_t idn_0 = idxn.impl_.at(0);
                index_t idn_1 = idxn.impl_.at(1);
                // Label fixed: the first field is idm_0 of s_acc, not "idn" of A.
                printf("in s_acc idm_0 is %d , idn_0 is %d, idn_1 is %d, counter is %d, value is: "
                       "%f\n",
                       idm_0,
                       idn_0,
                       idn_1,
                       counter,
                       ck_tile::type_convert<float>(s_acc(i_j_idx)));
            }
        });
    });
}
#endif
// move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
69114f25
...
...
@@ -173,26 +173,26 @@ struct FusedMoeGemmPipelineGeneralPolicy
// Returns the static tile distribution built from g_block_dstr_encode.
// The encoding is derived from the problem instead of being hard-coded: an outer
// encoding is assembled from Problem::BlockShape (WarpPerBlock_M0, Repeat_N0,
// WarpPerBlock_N0, Repeat_K0) and the BWarpDstrEncoding of the warp gemm returned
// by GetWarpGemm0<Problem>() is embedded into it.
// NOTE(review): presumably this is the distribution used to load the G tile from
// global memory (going by the function name) - confirm against the caller.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
    using WG = decltype(GetWarpGemm0<Problem>());
    using S_ = typename Problem::BlockShape;
    // The outer encoding below is only written for four warps along N0.
    static_assert(S_::WarpPerBlock_N0 == 4);
    constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
        sequence<S_::WarpPerBlock_M0>,
        tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
        tuple<sequence<0, 1>>,
        tuple<sequence<0, 1>>,
        sequence<1, 2>,
        sequence<0, 0>>{};
    // Single definition of g_block_dstr_encode. The hard-coded variant is kept only
    // as a comment below; having both active redefines the same name in one scope.
    constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
    // constexpr auto g_block_dstr_encode = tile_distribution_encoding<
    //     sequence<1>,
    //     tuple<sequence<1, 4, 32>, sequence<4, 2, 4>>,
    //     tuple<sequence<0, 1>, sequence<2, 1>>,
    //     tuple<sequence<0, 1>, sequence<1, 2>>,
    //     sequence<1, 2, 2>,
    //     sequence<0, 0, 2>>{};
    return make_static_tile_distribution(g_block_dstr_encode);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment