Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9ec586fc
Commit
9ec586fc
authored
Nov 27, 2024
by
letaoqin
Browse files
change a matrix lds desc
parent
66efcf96
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
84 deletions
+18
-84
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
...tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
+1
-1
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
...16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
+1
-1
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
...ed_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
+16
-82
No files found.
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_api.cpp
View file @
9ec586fc
...
...
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if
(
t
.
prec_i
==
"bf16"
&&
t
.
prec_w
==
"bf16"
&&
t
.
prec_o
==
"bf16"
&&
t
.
prec_st
==
"fp32"
&&
t
.
prec_sw
==
"fp32"
&&
t
.
prec_sq
==
"fp32"
&&
t
.
prec_kw
==
"fp32"
&&
t
.
block_m
==
32
&&
t
.
gate_only
==
1
)
{
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
using
t_
=
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
128
>
,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
;
r
=
fused_moegemm_
<
t_
>
(
s
,
a
);
}
// clang-format on
...
...
example/ck_tile/16_fused_moe_general/instances/fused_moegemm_bf16_m32.cpp
View file @
9ec586fc
...
...
@@ -8,7 +8,7 @@
// clang-format off
template
float
fused_moegemm_
<
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
128
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
fmoe_
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
float
,
float
,
float
,
S
<
32
,
128
,
32
,
128
>,
S
<
1
,
4
,
1
>
,
S
<
32
,
32
,
8
>
,
1
,
0
>
>
(
const
ck_tile
::
stream_config
&
s
,
fused_moegemm_args
a
);
// clang-format on
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp
View file @
9ec586fc
...
...
@@ -201,7 +201,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
using
S_
=
typename
Problem
::
BlockShape
;
constexpr
index_t
K2
=
S_
::
Warp_K0
;
constexpr
index_t
K1
=
get_warp_size
()
/
S_
::
Warp_N0
;
constexpr
index_t
K0
=
kKPerBlock
/
(
K1
*
K2
)
;
constexpr
index_t
K0
=
S_
::
Repeat_K0
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
...
...
@@ -277,93 +277,27 @@ struct FusedMoeGemmPipelineGeneralPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreDesc_A
()
{
// A async->LDS
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M0
;
constexpr
index_t
Block_K
=
Problem
::
BlockShape
::
Block_K0
;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
BlockShape
::
NumWarps
;
constexpr
index_t
kK1
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
kK0
=
Block_K
/
kK1
;
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
KVector
=
GetAlignment_A
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
KPad
=
KPack
;
// pad between warps
static_assert
(
Block_K
%
kK1
==
0
);
static_assert
(
Block_K
%
KVector
==
0
);
constexpr
index_t
LanesPerK
=
Block_K
/
KVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>=
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
Block_M
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kK0
>
{},
number
<
Block_M
>
{},
number
<
kK1
>
{}),
make_tuple
(
number
<
(
Block_M
+
1
)
*
kK1
>
{},
number
<
kK1
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
// make_pass_through_transform(),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
wavesPerM
>
{})),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
Block_M
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
Block_M
>
{}),
make_merge_transform
(
make_tuple
(
number
<
kK0
>
{},
number
<
kK1
>
{}))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
KPad
)
>
{},
// m0
number
<
Block_K
>
{},
// m1
number
<
warpSize
*
KVector
+
KPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
// make_pass_through_transform(number<NumIssues>{}),
// make_pass_through_transform(number<NumWarps>{}),
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
return
a_lds_block_desc
;
}
template
<
typename
Problem
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment