Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1561fc22
Commit
1561fc22
authored
Nov 20, 2024
by
“letaoqin”
Browse files
change indexing adapter to gather matrix
parent
1caa8198
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
5 deletions
+59
-5
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+11
-2
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+25
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
...s/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
+23
-1
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
1561fc22
...
...
@@ -246,8 +246,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << std::endl;
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(a_host(i,j)) << "\t";
// }
// std::cout << std::endl;
// }
std
::
cout
<<
sorted_token_ids_host
<<
std
::
endl
;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
// std::cout << topk_weight_host << std::endl;
...
...
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
1561fc22
...
...
@@ -65,6 +65,8 @@ struct indexing_adaptor
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor
(
const
IndexingType
*
idx
)
:
cached_idx_
(
idx
)
{}
const
IndexingType
*
cached_idx_
;
mutable
index_t
preUpIndex
=
0
;
mutable
index_t
preLowIndex
=
0
;
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
...
...
@@ -74,6 +76,13 @@ struct indexing_adaptor
"wrong! inconsistent # of dimension"
);
idx_low
(
number
<
0
>
{})
=
*
(
cached_idx_
+
idx_up
[
number
<
0
>
{}]);
preUpIndex
=
idx_up
[
number
<
0
>
{}];
preLowIndex
=
idx_low
(
number
<
0
>
{});
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
printf
(
"
\n
first index from %d to %d
\n
"
,
idx_up
[
number
<
0
>
{}],
idx_low
(
number
<
0
>
{}));
}
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
...
...
@@ -86,8 +95,22 @@ struct indexing_adaptor
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
int
up_index
=
idx_diff_up
[
number
<
0
>
{}]
+
preUpIndex
;
int
low_index
=
*
(
cached_idx_
+
up_index
);
idx_diff_low
(
number
<
0
>
{})
=
low_index
-
preLowIndex
;
preUpIndex
=
up_index
;
preLowIndex
=
low_index
;
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
printf
(
"
\n
index form %d to %d, diff from %d to %d
\n
"
,
up_index
,
low_index
,
idx_diff_up
[
number
<
0
>
{}],
idx_diff_low
(
number
<
0
>
{}));
}
// pass the diff to lower, but not changing the actually index
}
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_gl.hpp
View file @
1561fc22
...
...
@@ -97,13 +97,35 @@ struct FusedMoeGemmPipeline_FlatmmGl
index_t
hidden_size
,
index_t
intermediate_size
)
{
ignore
=
a_window_
;
ignore
=
g_window_
;
ignore
=
d_window_
;
ignore
=
o_window_
;
ignore
=
smem
;
ignore
=
hidden_size
;
ignore
=
intermediate_size
;
auto
a_copy_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
a_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
auto
a_dram
=
load_tile
(
a_copy_dram_window
);
//check a matrix gather right or not
constexpr
auto
a_spans
=
decltype
(
a_dram
)
::
get_distributed_spans
();
int
counter
=
0
;
sweep_tile_span
(
a_spans
[
number
<
0
>
{}],
[
&
](
auto
idxm
)
{
sweep_tile_span
(
a_spans
[
number
<
1
>
{}],
[
&
](
auto
idxk
){
constexpr
auto
i_j_idx
=
make_tuple
(
idxm
,
idxk
);
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
){
counter
=
counter
+
1
;
index_t
idm_0
=
idxm
.
impl_
.
at
(
0
);
index_t
idn_0
=
idxk
.
impl_
.
at
(
0
);
printf
(
"in A idm is %d , idn_ is %d , counter is %d, value is: %f
\n
"
,
idm_0
,
idn_0
,
counter
,
ck_tile
::
type_convert
<
float
>
(
a_dram
(
i_j_idx
)));
}
});
});
ignore
=
a_spans
;
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment