Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
727f201d
Commit
727f201d
authored
Dec 18, 2024
by
letaoqin
Browse files
change save o to lds data type to float
parent
28252273
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
20 deletions
+25
-20
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+25
-20
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
727f201d
...
...
@@ -256,9 +256,9 @@ struct FusedMoeGemmPipeline_General
constexpr
auto
w_dstr
=
make_static_tile_distribution
(
detail
::
make_reduce_tile_distribution_encoding
(
s_acc
.
get_tile_distribution
().
get_static_tile_distribution_encoding
(),
sequence
<
1
>
{}));
auto
w_global_to_dram_window
=
make_tile_window
(
w_window_
.
get_bottom_tensor_view
(),
s_acc
.
get_tile_distribution
().
get_static_tile_distribution_encoding
(),
sequence
<
1
>
{}));
auto
w_global_to_dram_window
=
make_tile_window
(
w_window_
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
BlockShape
::
Block_M0
>
{}),
w_window_
.
get_window_origin
(),
w_dstr
);
...
...
@@ -307,9 +307,10 @@ struct FusedMoeGemmPipeline_General
PrintMem(d,"D",0);
#endif
// add to LDS
CK_TILE_LDS_ADDR
float
*
smem_3
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
float
*>
(
smem
);
auto
o_lds_view
=
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
atomic_add
>
(
smem_
0
,
make_naive_tensor_view
<
address_space_enum
::
lds
,
memory_operation_enum
::
set
>
(
smem_
3
,
make_tuple
(
number
<
128
>
{},
number
<
32
>
{}),
make_tuple
(
32
,
1
),
number
<
8
>
{},
...
...
@@ -333,12 +334,16 @@ struct FusedMoeGemmPipeline_General
move_tile_window
(
o_olds_win
,
{
32
,
0
});
auto
o1
=
load_tile
(
o_olds_win
);
static_for
<
0
,
thread_buffer_size
,
1
>
{}([
&
](
auto
i
)
{
o0
.
get_thread_buffer
()(
i
)
=
type_convert
<
ODataType
>
(
type_convert
<
float
>
(
o0
.
get_thread_buffer
()[
i
])
+
o0
.
get_thread_buffer
()(
i
)
=
type_convert
<
float
>
(
type_convert
<
float
>
(
o0
.
get_thread_buffer
()[
i
])
+
type_convert
<
float
>
(
o1
.
get_thread_buffer
()[
i
]));
});
});
update_tile
(
o_window_
,
o0
);
// tile_elementwise_inout([&weight](auto& x) { x = x *
// type_convert<float>(weight); },
// o0);
auto
o
=
cast_tile
<
ODataType
>
(
o0
);
update_tile
(
o_window_
,
o
);
// restore pos
move_tile_window
(
o_olds_win
,
{
-
32
*
(
BlockShape
::
Repeat_K1
-
1
),
0
});
}
...
...
@@ -359,8 +364,8 @@ struct FusedMoeGemmPipeline_General
// move out window and save data
tile_elementwise_inout
([
&
weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
weight
);
},
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_alds_win
,
o
);
//
auto o = cast_tile<ODataType>(o_acc);
store_tile
(
o_alds_win
,
o
_acc
);
block_sync_lds
();
save_o
();
...
...
@@ -375,10 +380,10 @@ struct FusedMoeGemmPipeline_General
gemm_1
(
o_acc
,
y
,
d
);
// block_sync_lds();
tile_elementwise_inout
(
[
&
weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
weight
);
},
o_acc
);
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_alds_win
,
o
);
tile_elementwise_inout
(
[
&
weight
](
auto
&
x
)
{
x
=
x
*
type_convert
<
float
>
(
weight
);
},
o_acc
);
//
auto o = cast_tile<ODataType>(o_acc);
store_tile
(
o_alds_win
,
o
_acc
);
block_sync_lds
();
save_o
();
// store_tile(o_window_, o);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment