Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
15e76415
Commit
15e76415
authored
Dec 12, 2024
by
letaoqin
Browse files
add padding to O
parent
b885995c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
23 deletions
+28
-23
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
...ile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
+6
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+22
-22
No files found.
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp
View file @
15e76415
...
@@ -362,9 +362,14 @@ struct FusedMoeGemmGlKernel
...
@@ -362,9 +362,14 @@ struct FusedMoeGemmGlKernel
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
auto
o_
window_
=
make_tile_windo
w
(
auto
o_
padd_view_
=
pad_tensor_vie
w
(
o_scatter_view_
,
o_scatter_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
sequence
<
true
,
0
>
{});
auto
o_window_
=
make_tile_window
(
o_padd_view_
,
make_tuple
(
number
<
BlockShape
::
Block_M1
>
{},
number
<
BlockShape
::
Block_N1
>
{}),
{
idx_m0
,
0
});
{
idx_m0
,
0
});
return
o_window_
;
return
o_window_
;
}();
}();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
15e76415
...
@@ -308,6 +308,27 @@ struct FusedMoeGemmPipeline_General
...
@@ -308,6 +308,27 @@ struct FusedMoeGemmPipeline_General
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution_O
<
Problem
>());
ignore
=
o_alds_win
;
ignore
=
o_alds_win
;
auto
save_o
=
[
&
]()
{
if
(
blockIdx
.
x
==
0
&&
(
blockIdx
.
y
==
0
||
blockIdx
.
y
==
1
)
&&
blockIdx
.
z
==
0
)
{
if
(
threadIdx
.
x
<
64
)
{
auto
o0
=
load_tile
(
o_olds_win
);
for
(
int
step
=
1
;
step
<
4
;
step
++
)
{
move_tile_window
(
o_olds_win
,
{
32
,
0
});
auto
o1
=
load_tile
(
o_olds_win
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
o0
.
get_thread_buffer
()(
i
)
=
type_convert
<
ODataType
>
(
type_convert
<
float
>
(
o0
.
get_thread_buffer
()[
i
])
+
type_convert
<
float
>
(
o1
.
get_thread_buffer
()[
i
]));
}
}
update_tile
(
o_window_
,
o0
);
}
}
};
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
constexpr
index_t
kN1
=
BlockShape
::
Block_N1
;
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
const
index_t
n1_loops
=
ck_tile
::
integer_divide_ceil
(
hidden_size
,
kN1
);
index_t
iCounter1
=
n1_loops
-
1
;
index_t
iCounter1
=
n1_loops
-
1
;
...
@@ -336,30 +357,9 @@ struct FusedMoeGemmPipeline_General
...
@@ -336,30 +357,9 @@ struct FusedMoeGemmPipeline_General
// tile_elementwise_inout(
// tile_elementwise_inout(
// [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
// [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
#if 0
PrintMem(o, "O", 65);
#endif
store_tile
(
o_alds_win
,
o
);
store_tile
(
o_alds_win
,
o
);
block_sync_lds
();
block_sync_lds
();
if
(
blockIdx
.
x
==
0
&&
blockIdx
.
y
==
0
&&
blockIdx
.
z
==
0
)
save_o
();
{
if
(
threadIdx
.
x
<
64
)
{
auto
o0
=
load_tile
(
o_olds_win
);
for
(
int
step
=
1
;
step
<
4
;
step
++
)
{
move_tile_window
(
o_olds_win
,
{
32
,
0
});
auto
o1
=
load_tile
(
o_olds_win
);
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
o0
.
get_thread_buffer
()(
i
)
=
type_convert
<
ODataType
>
(
type_convert
<
float
>
(
o0
.
get_thread_buffer
()[
i
])
+
type_convert
<
float
>
(
o1
.
get_thread_buffer
()[
i
]));
}
}
update_tile
(
o_window_
,
o0
);
}
}
// store_tile(o_window_, o);
// store_tile(o_window_, o);
#if 0
#if 0
PrintMem(o,"O");
PrintMem(o,"O");
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment