Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e97fdbc3
Commit
e97fdbc3
authored
Dec 20, 2024
by
letaoqin
Browse files
change gather index adaptor
parent
727f201d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
12 additions
and
22 deletions
+12
-22
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+1
-1
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+9
-9
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+0
-4
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+2
-8
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
e97fdbc3
...
...
@@ -501,7 +501,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
c_dev
=
c_buf
.
ToHost
<
ADataType
>
();
std
::
cout
<<
std
::
endl
;
// std::cout << c_dev << std::endl;
std
::
cout
<<
o_dev
<<
std
::
endl
;
//
std::cout << o_dev << std::endl;
// int count = 0;
// std::cout << "[";
// for(int i = 0; i < tokens; i++)
...
...
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
e97fdbc3
...
...
@@ -81,7 +81,7 @@ struct indexing_adaptor
#if Using_Gather
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_low_index_
=
idx_low
(
number
<
0
>
{});
#if
1
#if
0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
...
...
@@ -100,30 +100,30 @@ struct indexing_adaptor
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
(
void
)
idx_up
;
#if !Using_Gather
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
idx_low
+=
idx_diff_low
;
#else
int
up_index
=
idx_diff_up
[
number
<
0
>
{}]
+
pre_up_index_
;
int
low_index
=
*
(
cached_idx_
+
up_index
);
idx_low
(
number
<
0
>
{})
=
low_index
;
idx_diff_low
(
number
<
0
>
{})
=
low_index
-
pre_low_index_
;
pre_up_index_
=
up_index
;
pre_low_index_
=
low_index
;
#if 1
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf
(
"
\n
index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d
\n
"
,
printf("\n
end
index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d
, pre_low_index_: %d pre_up_index_: %d
\n",
up_index,
low_index,
idx_diff_low(number<0>{}),
idx_diff_up[number<0>{}],
idx_low(number<0>{}),
idx_up
.
at
(
number
<
0
>
{}));
idx_up.at(number<0>{}),
pre_low_index_,
pre_up_index_);
}
#endif
#endif
// pass the diff to lower, but do not change the actual index
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
e97fdbc3
...
...
@@ -274,7 +274,6 @@ struct FusedMoeGemmPipeline_General
}
}
#endif
ignore
=
w
;
// y data
auto
bridge_llds_win
=
make_tile_window
(
bridge_lds_view
,
...
...
@@ -339,9 +338,6 @@ struct FusedMoeGemmPipeline_General
type_convert
<
float
>
(
o1
.
get_thread_buffer
()[
i
]));
});
});
// tile_elementwise_inout([&weight](auto& x) { x = x *
// type_convert<float>(weight); },
// o0);
auto
o
=
cast_tile
<
ODataType
>
(
o0
);
update_tile
(
o_window_
,
o
);
// restore pos
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
e97fdbc3
...
...
@@ -21,17 +21,11 @@ namespace ck_tile {
struct
FusedMoeGemmPipelineGeneralPolicy
{
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetAsyncCopyDwords
()
{
// TODO: always 1 dword
return
2
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment_A
()
{
// using async
constexpr
index_t
copy_bytes
=
4
*
GetAsyncCopyDwords
()
;
constexpr
index_t
copy_bytes
=
8
;
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
ADataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
...
...
@@ -196,7 +190,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
1
,
1
,
32
>
,
sequence
<
2
,
16
>>
,
tuple
<
sequence
<
1
,
2
,
16
>
,
sequence
<
4
,
8
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment