Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7881eff9
Commit
7881eff9
authored
Nov 28, 2024
by
letaoqin
Browse files
gemm down
parent
6a03c66f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
19 deletions
+50
-19
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+12
-5
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+38
-14
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
7881eff9
...
...
@@ -98,8 +98,6 @@ struct FusedMoeGemmPipeline_General
index_t
hidden_size
,
index_t
intermediate_size
)
{
ignore
=
d_window_
;
CK_TILE_LDS_ADDR
ADataType
*
smem_0
=
reinterpret_cast
<
CK_TILE_LDS_ADDR
ADataType
*>
(
smem
);
auto
a_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsBlockDesc_A
<
Problem
>());
...
...
@@ -194,12 +192,21 @@ struct FusedMoeGemmPipeline_General
while
(
iCounter1
>
0
)
{
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
block_sync_lds
();
move_tile_window
(
d_global_to_dram_window
,
{
kN1
,
0
});
d
=
load_tile
(
d_global_to_dram_window
);
iCounter1
--
;
}
ignore
=
y
;
ignore
=
d
;
store_tile
(
o_window_
,
a_dram_block
);
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
y
,
d
);
}
auto
o
=
cast_tile
<
ODataType
>
(
o_acc
);
store_tile
(
o_window_
,
o
);
// store_tile(o_window_, a_dram_block);
#if 0
//check a matrix gather right or not
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
7881eff9
...
...
@@ -12,6 +12,8 @@
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
namespace
ck_tile
{
...
...
@@ -209,19 +211,41 @@ struct FusedMoeGemmPipelineGeneralPolicy
return
BlockGemmASmemBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm1
()
{
using
S_
=
typename
Problem
::
BlockShape
;
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
YDataType
,
typename
Problem
::
DDataType
,
typename
Problem
::
AccDataType
,
S_
::
BlockSize
,
TileGemmShape
<
typename
S_
::
BlockTile_1
,
typename
S_
::
WarpPerBlock_1
,
typename
S_
::
WarpTile_1
>>
;
constexpr
auto
warp_gemm
=
GetWarpGemm1
<
Problem
>
();
using
BlockGemmPolicy
=
BlockGemmARegBRegCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
GDataType
,
typename
Problem
::
AccDataType
,
typename
S_
::
WarpPerBlock_1
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_D
()
{
using
S_
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_
N
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
0
>>
,
tuple
<
sequence
<
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
d_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_
M
1
>
,
tuple
<
sequence
<
S_
::
Repeat_N1
,
S_
::
WarpPerBlock_N1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
d_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
d_outer_dstr_enc
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
...
...
@@ -368,13 +392,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
WarpGemm
=
remove_cvref_t
<
decltype
(
GetWarpGemm1
<
Problem
>
())
>
;
// TODO: all waves a along different N, but same M
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_
M
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
0
>>
,
tuple
<
sequence
<
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
y_outer_dstr_enc
=
tile_distribution_encoding
<
sequence
<
S_
::
WarpPerBlock_
N
1
>
,
tuple
<
sequence
<
S_
::
Repeat_M1
,
S_
::
WarpPerBlock_M1
>
,
sequence
<
S_
::
Repeat_K1
>>
,
tuple
<
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
y_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
y_outer_dstr_enc
,
typename
WarpGemm
::
AWarpDstrEncoding
{});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment