Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
580d93dc
Commit
580d93dc
authored
Dec 17, 2024
by
letaoqin
Browse files
rewrite save o
parent
d4a0a8ee
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
25 deletions
+25
-25
example/ck_tile/17_fused_moe_general/main.cpp
example/ck_tile/17_fused_moe_general/main.cpp
+1
-1
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+8
-6
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+5
-13
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+11
-5
No files found.
example/ck_tile/17_fused_moe_general/main.cpp
View file @
580d93dc
...
@@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
...
@@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
o_dev
=
o_buf
.
ToHost
<
ODataType
>
();
auto
c_dev
=
c_buf
.
ToHost
<
ADataType
>
();
auto
c_dev
=
c_buf
.
ToHost
<
ADataType
>
();
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
// std::cout << o_dev << std::endl;
// std::cout << c_dev << std::endl;
// std::cout << c_dev << std::endl;
std
::
cout
<<
o_dev
<<
std
::
endl
;
// int count = 0;
// int count = 0;
// std::cout << "[";
// std::cout << "[";
// for(int i = 0; i < tokens; i++)
// for(int i = 0; i < tokens; i++)
...
...
include/ck_tile/core/algorithm/indexing_adaptor.hpp
View file @
580d93dc
...
@@ -81,7 +81,7 @@ struct indexing_adaptor
...
@@ -81,7 +81,7 @@ struct indexing_adaptor
#if Using_Gather
#if Using_Gather
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_up_index_
=
idx_up
[
number
<
0
>
{}];
pre_low_index_
=
idx_low
(
number
<
0
>
{});
pre_low_index_
=
idx_low
(
number
<
0
>
{});
#if
0
#if
1
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
{
printf
(
"
\n
first index from %d to %d
\n
"
,
idx_up
[
number
<
0
>
{}],
idx_low
(
number
<
0
>
{}));
printf
(
"
\n
first index from %d to %d
\n
"
,
idx_up
[
number
<
0
>
{}],
idx_low
(
number
<
0
>
{}));
...
@@ -93,8 +93,8 @@ struct indexing_adaptor
...
@@ -93,8 +93,8 @@ struct indexing_adaptor
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
/*
idx_low
*/
,
LowIdx
&
idx_low
,
const
UpIdx
&
/*
idx_up
*/
)
const
const
UpIdx
&
idx_up
)
const
{
{
// TODO: nonthing changed here
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
...
@@ -109,14 +109,16 @@ struct indexing_adaptor
...
@@ -109,14 +109,16 @@ struct indexing_adaptor
pre_up_index_
=
up_index
;
pre_up_index_
=
up_index
;
pre_low_index_
=
low_index
;
pre_low_index_
=
low_index
;
#if
0
#if
1
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
if
(
threadIdx
.
x
==
0
&&
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
{
{
printf("\n index form %d to %d,
diff from %d to %d
\n",
printf
(
"
\n
index form %d to %d,
idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d
\n
"
,
up_index
,
up_index
,
low_index
,
low_index
,
idx_diff_low
(
number
<
0
>
{}),
idx_diff_up
[
number
<
0
>
{}],
idx_diff_up
[
number
<
0
>
{}],
idx_diff_low(number<0>{}));
idx_low
(
number
<
0
>
{}),
idx_up
.
at
(
number
<
0
>
{}));
}
}
#endif
#endif
#endif
#endif
...
...
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
580d93dc
...
@@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel
...
@@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel
index_t
idx_n0
=
index_t
idx_n0
=
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_N0
);
__builtin_amdgcn_readfirstlane
(
intermediate_tile_id
*
BlockShape
::
Block_N0
);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// // position
// auto topk_weight =
// reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const
index_t
*
sorted_token_ids_ptr
=
const
index_t
*
sorted_token_ids_ptr
=
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
);
reinterpret_cast
<
const
index_t
*>
(
kargs
.
sorted_token_ids_ptr
);
...
@@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel
...
@@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel
}();
}();
const
auto
w_window
=
[
&
]()
{
const
auto
w_window
=
[
&
]()
{
const
TopkWeightDataType
*
w_ptr
=
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
);
const
TopkWeightDataType
*
w_ptr
=
const
auto
w_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
reinterpret_cast
<
const
TopkWeightDataType
*>
(
kargs
.
sorted_weight_ptr
);
const
auto
w_view_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
w_ptr
,
w_ptr
,
make_tuple
(
kargs
.
max_num_tokens_padded
),
make_tuple
(
kargs
.
max_num_tokens_padded
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
1
>
{},
number
<
1
>
{},
number
<
1
>
{});
number
<
1
>
{});
const
auto
w_window_
=
make_tile_window
(
const
auto
w_window_
=
w_view_
,
make_tile_window
(
w_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
{
idx_m0
});
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
{
idx_m0
});
return
w_window_
;
return
w_window_
;
}();
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
580d93dc
...
@@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General
...
@@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General
while
(
iCounter1
>
0
)
while
(
iCounter1
>
0
)
{
{
clear_tile
(
o_acc
);
clear_tile
(
o_acc
);
block_sync_lds
();
block_sync_lds
_direct_load
();
gemm_1
(
o_acc
,
y
,
d
);
gemm_1
(
o_acc
,
y
,
d
);
block_sync_lds
();
move_tile_window
(
d_global_to_dram_window
,
{
kN1
,
0
});
move_tile_window
(
d_global_to_dram_window
,
{
kN1
,
0
});
d
=
load_tile
(
d_global_to_dram_window
);
d
=
load_tile
(
d_global_to_dram_window
);
// move out window and save data
// move out window and save data
tile_elementwise_inout
([
&
weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
weight
);
},
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
store_tile
(
o_alds_win
,
o
);
move_tile_window
(
o_window_
,
{
kN1
,
0
});
block_sync_lds
();
save_o
();
move_tile_window
(
o_window_
,
{
0
,
kN1
});
iCounter1
--
;
iCounter1
--
;
}
}
// tail
// tail
{
{
clear_tile
(
o_acc
);
clear_tile
(
o_acc
);
block_sync_lds
();
block_sync_lds
_direct_load
();
gemm_1
(
o_acc
,
y
,
d
);
gemm_1
(
o_acc
,
y
,
d
);
// block_sync_lds();
// block_sync_lds();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment