Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1ba8a08f
Commit
1ba8a08f
authored
Aug 23, 2024
by
carlushuang
Browse files
update tmp work
parent
bf214665
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
348 additions
and
267 deletions
+348
-267
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
+2
-3
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
.../ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
+18
-16
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
...e/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
+313
-234
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
...le/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
+15
-14
No files found.
example/ck_tile/05_moe/fused_moe/kernel/fused_moe_kernel.hpp
View file @
1ba8a08f
...
...
@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
//
__shared__ char smem_ptr[GetSmemSize()];
ck_tile
::
index_t
num_sorted_tiles
=
__builtin_amdgcn_readfirstlane
(
*
reinterpret_cast
<
const
ck_tile
::
index_t
*>
(
kargs
.
num_sorted_tiles_ptr
));
ck_tile
::
index_t
tile_id
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
;);
...
...
@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window
,
d_gtile_window
,
o_gtile_window
,
scale
,
smem_ptr
);
scale
);
tile_id
+=
gridDim
.
x
;
}
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline.hpp
View file @
1ba8a08f
...
...
@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE
static
auto
GetAIndex
()
{
constexpr
auto
a_dist
=
Policy
::
template
Make
A
GlobalTileDistribution
<
Problem
>();
constexpr
auto
a_dist
=
Policy
::
template
MakeGlobalTileDistribution
_A
<
Problem
>();
const
auto
a_coord
=
a_dist
.
calculate_index
();
return
a_coord
;
}
...
...
@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView
&
o_gtile_window_tmp
,
// const void * sorted_weight_ptr,
ScaleDataType
scale
,
void
*
smem_ptr
,
CK_TILE_LDS_ADDR
void
*
smem_0
,
CK_TILE_LDS_ADDR
void
*
smem_1
,
index_t
dim_size
,
index_t
hidden_size
)
{
...
...
@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window
(
a_gtile_window_tmp
.
get_bottom_tensor_view
(),
a_gtile_window_tmp
.
get_window_lengths
(),
a_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
A
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_A
<
Problem
>());
auto
g_gtile_window
=
make_tile_window
(
g_gtile_window_tmp
.
get_bottom_tensor_view
(),
g_gtile_window_tmp
.
get_window_lengths
(),
g_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeG
G
lobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_G
<
Problem
>());
auto
u_gtile_window
=
make_tile_window
(
u_gtile_window_tmp
.
get_bottom_tensor_view
(),
u_gtile_window_tmp
.
get_window_lengths
(),
u_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
U
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_U
<
Problem
>());
auto
d_gtile_window
=
make_tile_window
(
d_gtile_window_tmp
.
get_bottom_tensor_view
(),
d_gtile_window_tmp
.
get_window_lengths
(),
d_gtile_window_tmp
.
get_window_origin
(),
Policy
::
template
Make
D
GlobalTileDistribution
<
Problem
>());
Policy
::
template
MakeGlobalTileDistribution
_D
<
Problem
>());
auto
o_gtile_window
=
make_tile_window
(
o_gtile_window_tmp
.
get_bottom_tensor_view
(),
...
...
@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto
a_smem_ptr
=
reinterpret_cast
<
ADataType
*>
(
smem_ptr
)
+
a_smem_offset
;
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
a_smem_ptr
,
Policy
::
template
MakeALdsStoreBlockDescriptor
<
Problem
>()),
Policy
::
template
MakeALdsStoreBlockDescriptor
<
Problem
>().
get_lengths
(),
{
0
,
0
});
auto
smem_0_window
=
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
smem_0
,
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>()),
Policy
::
template
MakeLdsStoreBlockDescriptor_A
<
Problem
>().
get_lengths
(),
{
0
,
0
});
async_load_tile
_raw
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{}))
,
a_gtile_window
);
async_load_tile
(
k_lds_store
(
LdsSeq
.
at
(
number
<
0
>
{})));
for
(
index_t
i_0
=
0
;
i_0
<
loops_0
;
i_0
++
)
{}
}
...
...
@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse
));
}
buffer_load_fence
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
buffer_load_fence
_raw
(
0
);
// rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it
return
o_acc
;
...
...
@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
_raw
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...
...
@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
_raw
(
k_dram_window
.
get_num_access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
...
...
@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
k0_loops
<=
2
)
__builtin_amdgcn_sched_barrier
(
0
);
async_load_fence
();
async_load_fence
_raw
();
__builtin_amdgcn_s_barrier
();
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
...
...
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_pipeline_policy.hpp
View file @
1ba8a08f
...
...
@@ -14,7 +14,6 @@ namespace ck_tile {
struct
FusedMoePipelinePolicy
{
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetAsyncCopyDwords
()
{
// TODO:
...
...
@@ -22,7 +21,7 @@ struct FusedMoePipelinePolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
_
A
()
{
// using async
static
constexpr
index_t
copy_bytes
=
4
*
GetAsyncCopyDwords
();
...
...
@@ -32,54 +31,27 @@ struct FusedMoePipelinePolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentG
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
_
G
()
{
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
if
constexpr
(
Problem
::
Traits
::
GateUpPreShuffled
)
{
return
4
*
4
;
}
else
{
return
4
*
GetAsyncCopyDwords
();
}
}();
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
return
16
;
}();
static
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
GDataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentU
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
_
U
()
{
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
if
constexpr
(
Problem
::
Traits
::
GateUpPreShuffled
)
{
return
4
*
4
;
}
else
{
return
4
*
GetAsyncCopyDwords
();
}
}();
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
return
16
;
}();
static
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
UDataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentD
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignment
_
D
()
{
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
if
constexpr
(
Problem
::
Traits
::
DownPreShuffled
)
{
return
4
*
4
;
}
else
{
return
4
*
GetAsyncCopyDwords
();
}
}();
static
constexpr
index_t
copy_bytes
=
[
&
]()
{
return
16
;
}();
static
constexpr
index_t
data_bytes
=
sizeof
(
typename
Problem
::
DDataType
);
static_assert
(
copy_bytes
%
data_bytes
==
0
);
return
copy_bytes
/
data_bytes
;
...
...
@@ -93,29 +65,11 @@ struct FusedMoePipelinePolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPack
_
A
()
{
return
GetSmemKPack
<
typename
Problem
::
ADataType
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackG
()
{
return
GetSmemKPack
<
typename
Problem
::
GDataType
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackU
()
{
return
GetSmemKPack
<
typename
Problem
::
UDataType
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemKPackD
()
{
return
GetSmemKPack
<
typename
Problem
::
DDataType
>
();
}
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
Alignment
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_SimpleMxK
()
{
...
...
@@ -206,102 +160,58 @@ struct FusedMoePipelinePolicy
// Caution: this will require global memory pre-shuffled to follow the mfma layout
// to maximize the L1/L2 channel while skip LDS
/*
(b) n0 n1 n2 k0 k1 k2
klanes
|
nr 4 kr 4 16 8
(b) n0 n1 k0 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
klanes
|
nr kr 4 4 16 8
(b) n0 k0 n1 k1 n2 k2 -> kthreads
| |
V V
waves nlanes
*/
template
<
typename
BlockTile
,
typename
BlockWarps
,
typename
WarpGemm
,
index_t
Alignment
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_MatrixCore_Swizzled_NxK
()
template
<
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
WavesPerBlock_N
,
index_t
WavesPerBlock_K
,
typename
WarpGemm
,
index_t
Alignment
,
FusedMoePermuteStyle
PermuteStyle
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution_MatrixCore_Swizzled
()
{
static_assert
(
Alignment
%
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKPerLane
==
0
);
static_assert
(
BlockWarps
{}.
at
(
number
<
0
>
{})
==
1
&&
BlockWarps
{}.
at
(
number
<
2
>
{})
==
1
);
static
constexpr
index_t
NumWarps
=
reduce_on_sequence
(
BlockWarps
{},
multiplies
{},
number
<
1
>
{});
constexpr
index_t
NPerBlock
=
BlockTile
{}.
at
(
number
<
1
>
{});
constexpr
index_t
KPerBlock
=
BlockTile
{}.
at
(
number
<
2
>
{});
constexpr
index_t
K2
=
Alignment
;
constexpr
index_t
N2
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
K1
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
N1
=
NumWarps
;
static_assert
(
NPerBlock
%
(
N1
*
N2
)
==
0
);
static_assert
(
KPerBlock
%
(
K1
*
K2
)
==
0
);
if
constexpr
(
PermuteStyle
==
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
)
{
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr
index_t
Kv
=
Alignment
;
constexpr
index_t
Nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
Kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K0
=
KPerBlock
/
(
K1
*
K2
);
constexpr
index_t
N0
=
NPerBlock
/
(
N1
*
N2
);
static_assert
(
KPerBlock
%
(
K1
*
K2
)
==
0
);
constexpr
index_t
Nr
=
NPerBlock
/
Nw
;
constexpr
index_t
Kr
=
KPerBlock
/
(
Kv
*
Kw
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_wav
>
,
sequence
<
K_wav
,
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
2
>>
{});
constexpr
index_t
Nr_p
=
WavesPerBlock_N
;
constexpr
index_t
Kr_p
=
WavesPerBlock_K
;
constexpr
index_t
Nr_y
=
Nr
/
Nr_p
;
constexpr
index_t
Kr_y
=
Kr
/
Kr_p
;
if
constexpr
(
get_warp_size
()
<
K_rem
)
{
static_assert
(
K_rem
%
get_warp_size
()
==
0
);
constexpr
index_t
K_lan
=
get_warp_size
();
// lane within same wave is along gemm-k
constexpr
index_t
K_wav
=
K_rem
/
get_warp_size
();
static_assert
(
K_wav
<=
NumWarps
,
"not not support thread has repeat along K yet"
);
constexpr
index_t
M_wav
=
NumWarps
/
K_wav
;
static_assert
(
MPerBlock
%
M_wav
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
M_wav
;
// NOTE: no swap, but hard to avoid LDS bank conflict
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_wav
>
,
sequence
<
K_wav
,
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
1
,
0
>
,
sequence
<
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
2
>>
{});
}
else
{
constexpr
index_t
K_lan
=
K_rem
;
constexpr
index_t
M_lan
=
get_warp_size
()
/
K_lan
;
constexpr
index_t
M_wav
=
NumWarps
;
static_assert
(
MPerBlock
%
(
M_lan
*
M_wav
)
==
0
,
"this tile size is too small please check"
);
constexpr
index_t
M_rep
=
MPerBlock
/
(
M_lan
*
M_wav
);
// NOTE: swapped for LDS load bank conflict free
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M_rep
,
M_lan
,
M_wav
>
,
sequence
<
K_lan
,
K_vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
1
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
sequence
<
1
>
,
// 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple
<
sequence
<
Nr_y
,
Nr_p
>
,
sequence
<
Kr_y
,
Kr_p
>
,
sequence
<
Kw
,
Nw
,
Kv
>>
,
// Nr_p, Kr_p Kw Nw
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
,
3
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
,
1
>>
,
// Nr_y Kr_y Kv
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
2
>>
{});
// clang-format on
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
A
GlobalTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution
_A
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentA
<
Problem
>
();
constexpr
index_t
Alignment
=
GetAlignment
_
A
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
kMPerBlock
,
kKPerBlock
,
NumWarps
,
...
...
@@ -309,42 +219,75 @@ struct FusedMoePipelinePolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeG
G
lobalTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution
_G
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_g
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentG
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
kNPerBlock
,
kKPerBlock
,
NumWarps
,
Alignment
>
();
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_u
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_G
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
WavesPerBlock_K
,
WarpGemm
,
Alignment
,
PermuteStype
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
U
GlobalTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution
_U
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_u
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentU
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
kNPerBlock
,
kKPerBlock
,
NumWarps
,
Alignment
>
();
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_u
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm0BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm0
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_U
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
WavesPerBlock_K
,
WarpGemm
,
Alignment
,
PermuteStype
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
D
GlobalTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeGlobalTileDistribution
_D
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_d
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_y
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentD
<
Problem
>
();
return
MakeGlobalTileDistribution_SimpleMxK_Async
<
kNPerBlock
,
kKPerBlock
,
NumWarps
,
Alignment
>
();
constexpr
auto
PermuteStype
=
Problem
::
Traits
::
PermuteStyle
;
if
constexpr
(
PermuteStype
==
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
)
{
constexpr
index_t
kNPerBlock
=
Problem
::
FusedMoeTileShape
::
kN_d
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_y
;
constexpr
index_t
WavesPerBlock_N
=
Problem
::
Gemm1BlockWarps
{}
::
at
(
number
<
1
>
{});
constexpr
index_t
WavesPerBlock_K
=
Problem
::
Gemm1BlockWarps
{}
::
at
(
number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
GetWarpGemm1
<
Problem
>
()
>
;
constexpr
index_t
Alignment
=
GetAlignment_D
<
Problem
>
();
return
MakeGlobalTileDistribution_MatrixCore_Swizzled
<
kNPerBlock
,
kKPerBlock
,
WavesPerBlock_N
,
WavesPerBlock_K
,
WarpGemm
,
Alignment
,
PermuteStype
>
();
}
}
template
<
index_t
MPerBlock
,
...
...
@@ -359,10 +302,8 @@ struct FusedMoePipelinePolicy
constexpr
index_t
kBlockSize
=
ck_tile
::
get_warp_size
()
*
NumWarps
;
// Problem::kBlockSize;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr
index_t
KVector
=
Alignment
;
// GetAlignmentK<Problem>(); // this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
constexpr
index_t
KVector
=
Alignment
;
// this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps
static_assert
(
warpSize
*
KVector
>=
KPerBlock
&&
warpSize
*
KVector
%
KPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
KPerBlock
/
KVector
;
// within a wave
...
...
@@ -402,77 +343,188 @@ struct FusedMoePipelinePolicy
return
lds_block_desc
;
}
template
<
index_t
MPerBlock
,
index_t
KPerBlock
,
index_t
NumWarps
,
index_t
KPack
,
index_t
Alignement
,
index_t
IBuf
=
0
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSmemStoreBlockDescriptor_SimpleMxK_Async
(
number
<
IBuf
>
=
number
<
0
>
{})
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLdsStoreBlockDescriptor_A
()
{
constexpr
index_t
kBlockSize
=
ck_tile
::
get_warp_size
()
*
NumWarps
;
// Problem::kBlockSize;
// A async->LDS
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
kVector
=
GetAlignment_A
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
kPad
=
KPack
;
// pad between warps
// constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
// constexpr index_t Alignement = GetAlignmentK<Problem>(); // this is for global load
constexpr
index_t
kPad
=
KPack
;
// for async-copy, this pad is between warps. Optimize this for lds_read speed
static_assert
(
warpSize
*
Alignement
>=
KPerBlock
&&
warpSize
*
Alignement
%
KPerBlock
==
0
);
constexpr
index_t
LanesPerK
=
KPerBlock
/
Alignement
;
// how many lane (within a wave) to load K
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// how many groups (within a wave), they may load different N, but same K
constexpr
index_t
NumIssues
=
MPerBlock
/
(
LaneGroups
*
NumWarps
);
static_assert
(
NumIssues
==
MPerBlock
*
KPerBlock
/
(
BlockSize
*
Alignement
));
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor_with_offset
(
make_tuple
(
number
<
NumIssues
>
{},
// n0
number
<
LaneGroups
>
{},
// n1
number
<
NumWarps
>
{},
// n2
number
<
LanesPerK
>
{},
// k0
number
<
Alignement
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
Alignement
+
kPad
)
>
{},
number
<
KPerBlock
>
{},
number
<
warpSize
*
Alignement
+
kPad
>
{},
number
<
Alignement
>
{},
number
<
1
>
{}),
number
<
IBuf
*
GetSingleSmemElementSpaceSize
<
Problem
>
()
>
{},
number
<
Alignement
>
{},
number
<
1
>
{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr
auto
k_lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
Alignement
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
k_lds_block_desc_issues_warps_lanes
;
static_assert
(
kKPerBlock
%
kVector
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
kVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
kMPerBlock
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
kPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_merge_transform
(
make_tuple
(
number
<
wavesPerM
>
{},
number
<
wavesPerK
>
{})),
make_merge_transform
(
make_tuple
(
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
kMPerBlock
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m0
number
<
kKPerBlock
>
{},
// m1
number
<
warpSize
*
KVector
+
kPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KVector
>
{},
// lds store vector(actually no explicit store)
number
<
1
>
{});
constexpr
auto
lds_block_desc_issues_warps_lanes
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
number
<
NumIssues
>
{}),
make_pass_through_transform
(
number
<
NumWarps
>
{}),
make_merge_transform
(
make_tuple
(
number
<
LaneGroups
>
{},
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
>
{},
sequence
<
2
>
{},
sequence
<
1
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}));
return
lds_block_desc_issues_warps_lanes
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
A
SmemLoadTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeSmemLoadTileDistribution
_A
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentA
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPackA
<
Problem
>
();
constexpr
index_t
NumPrefetch
=
Problem
::
Traits
::
NumPrefetchA
;
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
warpSize
=
ck_tile
::
get_warp_size
();
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
return
MakeSmemLoadTileDescriptor_SimpleMxK_Async
<
kMPerBlock
,
kKPerBlock
,
NumWarps
,
Alignment
,
KPack
,
NumPrefetch
>
();
constexpr
index_t
KPack
=
GetSmemKPack_A
<
Problem
>
();
// LDS
constexpr
index_t
kVector
=
GetAlignment_A
<
Problem
>
();
// async copy 1 dword
constexpr
index_t
kPad
=
KPack
;
// pad between warps
static_assert
(
kKPerBlock
%
kVector
==
0
);
constexpr
index_t
LanesPerK
=
kKPerBlock
/
kVector
;
// how many thread loading K
if
constexpr
(
LanesPerK
>
warpSize
)
{
// need multiple waves to load K
static_assert
(
LanesPerK
%
warpSize
==
0
);
constexpr
index_t
wavesPerK
=
LanesPerK
/
warpSize
;
if
constexpr
(
wavesPerK
>
NumWarps
)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr
index_t
wavesPerM
=
NumWarps
/
wavesPerK
;
constexpr
index_t
NumIssues
=
kMPerBlock
/
wavesPerM
;
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
wavesPerM
>
{},
// m1
number
<
wavesPerK
>
{},
// k0
number
<
warpSize
>
{},
// k1
number
<
KVector
>
{}),
// k2
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m0
number
<
wavesPerK
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m1
number
<
warpSize
*
KVector
+
kPad
>
{},
// k0
number
<
KVector
>
{},
// k1
number
<
1
>
{}),
// k2
number
<
KPack
>
{},
// lds load vector
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
wavesPerM
>
{})),
make_merge_transform
(
make_tuple
(
number
<
wavesPerK
>
{},
number
<
warpSize
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
,
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_desc_m_k
;
}
}
else
{
// lanes within a wave load different M but same K
static_assert
(
warpSize
%
LanesPerK
==
0
);
constexpr
index_t
LaneGroups
=
warpSize
/
LanesPerK
;
// along m
constexpr
index_t
NumIssues
=
kMPerBlock
/
(
LaneGroups
*
NumWarps
);
constexpr
auto
lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
NumIssues
>
{},
// m0
number
<
LaneGroups
>
{},
// m1
number
<
NumWarps
>
{},
// m2
number
<
LanesPerK
>
{},
// k0
number
<
KVector
>
{}),
// k1
make_tuple
(
number
<
NumWarps
*
(
warpSize
*
KVector
+
kPad
)
>
{},
// m0
number
<
kKPerBlock
>
{},
// m1
number
<
warpSize
*
KVector
+
kPad
>
{},
// m2
number
<
KVector
>
{},
// k0
number
<
1
>
{}),
// k1
number
<
KPack
>
{},
// lds load vector
number
<
1
>
{});
constexpr
auto
lds_desc_m_k
=
transform_tensor_descriptor
(
lds_block_desc_0
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
NumIssues
>
{},
number
<
LaneGroups
>
{},
number
<
NumWarps
>
{})),
make_merge_transform
(
make_tuple
(
number
<
LanesPerK
>
{},
number
<
KVector
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
lds_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeASmemStoreTileDistribution
()
...
...
@@ -480,8 +532,8 @@ struct FusedMoePipelinePolicy
constexpr
index_t
kMPerBlock
=
Problem
::
FusedMoeTileShape
::
kM_a
;
constexpr
index_t
kKPerBlock
=
Problem
::
FusedMoeTileShape
::
kK_a
;
constexpr
index_t
NumWarps
=
Problem
::
FusedMoeTileShape
::
NumWarps
;
constexpr
index_t
Alignment
=
GetAlignmentA
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPackA
<
Problem
>
();
constexpr
index_t
Alignment
=
GetAlignment
_
A
<
Problem
>
();
constexpr
index_t
KPack
=
GetSmemKPack
_
A
<
Problem
>
();
constexpr
index_t
NumPrefetch
=
Problem
::
Traits
::
NumPrefetchA
;
return
MakeSmemStoreBlockDescriptor_SimpleMxK_Async
<
kMperBlock
,
...
...
@@ -492,13 +544,14 @@ struct FusedMoePipelinePolicy
Alignment
>
();
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGSmemLoadTileDistribution()
{
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr
index_t
Alignment
=
GetAlignmentG
<
Problem
>
();
constexpr index_t Alignment = GetAlignment
_
G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
...
...
@@ -515,7 +568,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_g;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr
index_t
Alignment
=
GetAlignmentG
<
Problem
>
();
constexpr index_t Alignment = GetAlignment
_
G<Problem>();
constexpr index_t KPack = GetSmemKPackG<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchG;
...
...
@@ -533,7 +586,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_u;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_a;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr
index_t
Alignment
=
GetAlignmentU
<
Problem
>
();
constexpr index_t Alignment = GetAlignment
_
U<Problem>();
constexpr index_t KPack = GetSmemKPackU<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchU;
...
...
@@ -551,7 +604,7 @@ struct FusedMoePipelinePolicy
constexpr index_t kNPerBlock = Problem::FusedMoeTileShape::kN_d;
constexpr index_t kKPerBlock = Problem::FusedMoeTileShape::kK_y;
constexpr index_t NumWarps = Problem::FusedMoeTileShape::NumWarps;
constexpr
index_t
Alignment
=
GetAlignmentD
<
Problem
>
();
constexpr index_t Alignment = GetAlignment
_
D<Problem>();
constexpr index_t KPack = GetSmemKPackD<Problem>();
constexpr index_t NumPrefetch = Problem::Traits::NumPrefetchD;
...
...
@@ -562,7 +615,32 @@ struct FusedMoePipelinePolicy
KPack,
NumPrefetch>();
}
#endif
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm0
()
{
return
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
GDataType
,
typename
Problem
::
AccDataType
,
Problem
::
FusedMoeTileShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
FusedMoeTileShape
::
Gemm0WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
FusedMoeTileShape
::
Gemm0WarpTile
::
at
(
number
<
2
>
{}),
true
/*TransposeC*/
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetWarpGemm1
()
{
return
WarpGemmMfmaDispatcher
<
typename
Problem
::
YDataType
,
typename
Problem
::
DDataType
,
typename
Problem
::
AccDataType
,
Problem
::
FusedMoeTileShape
::
Gemm1WarpTile
::
at
(
number
<
0
>
{}),
Problem
::
FusedMoeTileShape
::
Gemm1WarpTile
::
at
(
number
<
1
>
{}),
Problem
::
FusedMoeTileShape
::
Gemm1WarpTile
::
at
(
number
<
2
>
{}),
true
/*TransposeC*/
>
{};
}
#if 0
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetGemm0()
{
...
...
@@ -628,5 +706,6 @@ struct FusedMoePipelinePolicy
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
#endif
};
}
// namespace ck_tile
example/ck_tile/05_moe/fused_moe/pipeline/fused_moe_traits.hpp
View file @
1ba8a08f
...
...
@@ -12,21 +12,22 @@
namespace
ck_tile
{
template
<
bool
GateUpPreShuffled_
=
false
,
bool
DownPreShuffled_
=
false
,
index_t
NumPrefetchA_
=
2
,
index_t
NumPrefetchG_
=
2
,
index_t
NumPrefetchU_
=
2
,
index_t
NumPrefetchD_
=
2
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
enum
class
FusedMoePermuteStyle
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv
=
2
,
// 0,1,3,4,2,5
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
no_permute
=
999
,
};
template
<
bool
DownPreShuffled_
=
false
,
FusedMoePermuteStyle
PermuteStyle_
=
FusedMoePermuteStyle
::
permute_b_nr_kr_kw_nw_kv
,
index_t
kBlockPerCu_
=
-
1
/* overwrite occupancy if not -1 */
>
struct
FusedMoeTraits
{
static
constexpr
bool
GateUpPreShuffled
=
GateUpPreShuffled_
;
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
index_t
NumPrefetchA
=
NumPrefetchA_
;
static
constexpr
index_t
NumPrefetchG
=
NumPrefetchG_
;
static
constexpr
index_t
NumPrefetchU
=
NumPrefetchU_
;
static
constexpr
index_t
NumPrefetchD
=
NumPrefetchD_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
static
constexpr
bool
DownPreShuffled
=
DownPreShuffled_
;
static
constexpr
FusedMoePermuteStyle
PermuteStyle
=
PermuteStyle_
;
static
constexpr
index_t
kBlockPerCu
=
kBlockPerCu_
;
};
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment