Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ef8e3620
Commit
ef8e3620
authored
Nov 25, 2024
by
letaoqin
Browse files
gather and scatter right
parent
eaf8e616
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
33 additions
and
31 deletions
+33
-31
example/ck_tile/16_fused_moe_general/main.cpp
example/ck_tile/16_fused_moe_general/main.cpp
+16
-15
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+2
-2
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+7
-8
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+8
-6
No files found.
example/ck_tile/16_fused_moe_general/main.cpp
View file @
ef8e3620
...
...
@@ -60,15 +60,15 @@ auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype,
}
template
<
typename
IndexType
>
void
output_matrix_2d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
m
,
int
n
)
void
output_matrix_2d
(
ck_tile
::
HostTensor
<
IndexType
>&
data
,
int
m
,
int
n
)
{
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
m
;
i
++
)
{
std
::
cout
<<
"Line "
<<
i
<<
"
\t
"
;
for
(
int
j
=
0
;
j
<
n
;
j
++
)
{
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
i
,
j
))
<<
"
\t
"
;
std
::
cout
<<
ck_tile
::
type_convert
<
float
>
(
data
(
i
,
j
))
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
...
...
@@ -261,17 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host
.
mData
[
0
],
experts
,
block_m
);
// std::cout << std::endl;
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(a_host(i,j)) << "\t";
// }
// std::cout << std::endl;
// }
output_matrix_2d
(
a_host
,
tokens
,
hidden_size
);
// output_matrix_2d(a_host, tokens, hidden_size);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
// std::cout << sorted_expert_ids_host << std::endl;
...
...
@@ -381,7 +372,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_dev
,
o_host
,
std
::
string
(
"OUT Error: Incorrect results!"
),
rtol
,
atol
);
std
::
cout
<<
", valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
flush
;
output_matrix_2d
(
o_dev
,
tokens
,
hidden_size
);
// std::cout << std::endl;
// int count = 0;
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "Line " << i << "\t";
// for(int j = 0; j < hidden_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(o_dev(count++)) << "\t";
// }
// std::cout << std::endl;
// }
}
std
::
cout
<<
std
::
flush
<<
std
::
endl
;
...
...
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
ef8e3620
...
...
@@ -80,7 +80,7 @@ struct indexing_adaptor
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_low_index_
=
idx_low
(
number
<
0
>
{});
#if 0
if(threadIdx.x ==
65
&& blockIdx.x == 0 && blockIdx.y ==
1
&& blockIdx.z == 0)
if(threadIdx.x ==
0
&& blockIdx.x == 0 && blockIdx.y ==
0
&& blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
}
...
...
@@ -105,7 +105,7 @@ struct indexing_adaptor
pre_up_index_
=
up_index
;
pre_low_index_
=
low_index
;
#if 0
if(threadIdx.x ==
65
&& blockIdx.x == 0 && blockIdx.y ==
1
&& blockIdx.z == 0)
if(threadIdx.x ==
0
&& blockIdx.x == 0 && blockIdx.y ==
0
&& blockIdx.z == 0)
{
printf("\n index form %d to %d, diff from %d to %d \n",
up_index,
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
ef8e3620
...
...
@@ -78,7 +78,7 @@ struct FusedMoeGemmPipeline_General
BlockShape
::
Block_M0
*
BlockShape
::
Block_N0
*
sizeof
(
YDataType
);
return
max
(
smem_mat_a
,
smem_bridge
);
//return Policy::template GetSmemSize<Problem>();
//
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
...
...
@@ -108,7 +108,10 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreDesc_A
<
Problem
>());
auto
a_lds_win
=
make_tile_window
(
a_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
auto
a_lds_win
=
make_tile_window
(
a_lds_view
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{},
number
<
BlockShape
::
Block_K0
>
{}),
{
0
,
0
});
auto
a_global_to_dram_window
=
make_tile_window
(
a_window_
.
get_bottom_tensor_view
(),
...
...
@@ -116,15 +119,11 @@ struct FusedMoeGemmPipeline_General
a_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_A
<
Problem
>());
// auto o_win = make_tile_window_linear(
// o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
auto
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
o_window_
,
a_dram_block
);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
...
...
@@ -132,7 +131,7 @@ struct FusedMoeGemmPipeline_General
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxm, idxk);
if(threadIdx.x ==
65
&& blockIdx.x == 0 && blockIdx.y ==
1
&& blockIdx.z == 0)
if(threadIdx.x ==
0
&& blockIdx.x == 0 && blockIdx.y ==
0
&& blockIdx.z == 0)
{
counter = counter + 1;
index_t idm_0 = idxm.impl_.at(0);
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
ef8e3620
...
...
@@ -367,9 +367,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
// make_pass_through_transform(),
// make_pass_through_transform(),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
wavesPerM
>
{})),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -400,10 +401,11 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
//make_pass_through_transform(number<NumIssues>{}),
//make_pass_through_transform(number<NumWarps>{}),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
// make_pass_through_transform(number<NumIssues>{}),
// make_pass_through_transform(number<NumWarps>{}),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment