Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
b2030e34
Commit
b2030e34
authored
Nov 28, 2024
by
letaoqin
Browse files
s_acc data to lds to shuffle
parent
1d89463c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
42 deletions
+29
-42
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
...ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
+24
-17
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+5
-25
No files found.
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp
View file @
b2030e34
...
@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General
...
@@ -122,7 +122,7 @@ struct FusedMoeGemmPipeline_General
g_window_
.
get_window_origin
(),
g_window_
.
get_window_origin
(),
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution_G
<
Problem
>());
//
Block GEMM
//
gemm gate
constexpr
auto
gemm_0
=
Policy
::
template
GetBlockGemm0
<
Problem
>();
constexpr
auto
gemm_0
=
Policy
::
template
GetBlockGemm0
<
Problem
>();
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
using
SaccBlockTileType
=
decltype
(
gemm_0
.
MakeCBlockTile
());
auto
s_acc
=
SaccBlockTileType
{};
auto
s_acc
=
SaccBlockTileType
{};
...
@@ -135,36 +135,43 @@ struct FusedMoeGemmPipeline_General
...
@@ -135,36 +135,43 @@ struct FusedMoeGemmPipeline_General
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
auto
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
clear_tile
(
s_acc
);
// initialize C
clear_tile
(
s_acc
);
// initialize C
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
constexpr
index_t
kK0
=
BlockShape
::
Block_K0
;
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
kK0
);
const
index_t
k0_loops
=
ck_tile
::
integer_divide_ceil
(
intermediate_size
,
kK0
);
index_t
iCounter
=
k0_loops
-
1
;
index_t
iCounter
=
k0_loops
-
1
;
//gemm 0
while
(
iCounter
>
0
)
while
(
iCounter
>
0
)
{
{
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
block_sync_lds
();
block_sync_lds
();
move_tile_window
(
a_global_to_dram_window
,
{
0
,
kK0
});
move_tile_window
(
a_global_to_dram_window
,
{
0
,
kK0
});
move_tile_window
(
g_global_to_dram_window
,
{
0
,
kK0
});
move_tile_window
(
g_global_to_dram_window
,
{
0
,
kK0
});
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
a_dram_block
=
load_tile
(
a_global_to_dram_window
);
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
g_dram_block
=
load_tile
(
g_global_to_dram_window
);
store_tile
(
a_lds_win
,
a_dram_block
);
store_tile
(
a_lds_win
,
a_dram_block
);
iCounter
--
;
iCounter
--
;
}
}
// tail
// tail
{
{
block_sync_lds
();
block_sync_lds
();
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
gemm_0
(
s_acc
,
a_lds_win
,
g_dram_block
);
}
}
// move sacc to LDS
//move sacc to LDS
auto
bridge_lds_view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>());
ignore
=
g_dram_block
;
auto
bridge_lds_win
=
make_tile_window
(
bridge_lds_view
,
Policy
::
template
MakeBridgeLdsBlockDesc
<
Problem
>().
get_lengths
(),
{
0
,
0
});
auto
y_pre
=
cast_tile
<
YDataType
>
(
s_acc
);
store_tile
(
bridge_lds_win
,
y_pre
);
// gemm1 down
ignore
=
bridge_lds_win
;
store_tile
(
o_window_
,
a_dram_block
);
store_tile
(
o_window_
,
a_dram_block
);
#if 0
#if 0
//check a matrix gather right or not
//check a matrix gather right or not
...
...
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
b2030e34
...
@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -102,11 +102,8 @@ struct FusedMoeGemmPipelineGeneralPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_Bridge
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize_Bridge
()
{
{
constexpr
auto
bridge_sld_desc
=
MakeBridgeLdsLoadDesc
<
Problem
>
();
constexpr
auto
bridge_lds_desc
=
MakeBridgeLdsBlockDesc
<
Problem
>
();
constexpr
auto
bridge_sst_desc
=
MakeBridgeLdsStoreDesc
<
Problem
>
();
return
bridge_lds_desc
.
get_element_space_size
();
static_assert
(
bridge_sld_desc
.
get_element_space_size
()
==
bridge_sst_desc
.
get_element_space_size
());
return
bridge_sld_desc
.
get_element_space_size
();
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
...
@@ -296,30 +293,13 @@ struct FusedMoeGemmPipelineGeneralPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLds
Load
Desc
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLds
Block
Desc
()
{
{
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
constexpr
index_t
KPad
=
0
;
// pad between warps
constexpr
index_t
KPad
=
0
;
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
make_tuple
(
number
<
Block_N
+
KPad
>
{},
number
<
1
>
{}),
number
<
KVector
>
{},
number
<
1
>
{});
return
desc
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBridgeLdsStoreDesc
()
{
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N0
;
constexpr
index_t
KVector
=
GetSmemKPack_Y
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
0
;
// KVector; // pad between warps
constexpr
auto
desc
=
constexpr
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
make_naive_tensor_descriptor
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment