gaoqiong / composable_kernel_ROCM

Commit 1f9546e0
Authored Dec 12, 2024 by root

    Merge branch 'develop' into gemm_bf16_sk_muozturk

Parents: 78394194, 86990558
Changes: 472 files in total; showing 20 changed files with 3128 additions and 268 deletions (+3128 -268).
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp                 +73   -54
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp  +15   -6
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp                              +12   -5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp                             +17   -16
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp                       +32   -24
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp                         +11   -11
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp                             +14   -12
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp               +96   -134
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp                                          +26   -5
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp                                         +1    -1
include/ck_tile/ops/fused_moe.hpp                                                              +19   -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp                                  +421  -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp                                   +125  -0
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp                        +33   -0
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp                                    +303  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp                    +651  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp                +831  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp                    +354  -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp                      +46   -0
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp                                +48   -0
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

@@ -25,6 +25,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     using LSEDataType    = remove_cvref_t<typename Problem::LSEDataType>;
     using PDataType      = remove_cvref_t<typename Problem::PDataType>;
     using OaccDataType   = remove_cvref_t<typename Problem::OaccDataType>;
+    using ODataType      = remove_cvref_t<typename Problem::ODataType>;
     using FmhaMask       = remove_cvref_t<typename Problem::FmhaMask>;
     using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

@@ -34,12 +35,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;

@@ -47,7 +49,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
-    static constexpr bool kStoreLSE    = true; // always store LSE (acc)
+    static constexpr bool kStoreLSE    = Problem::kStoreLSE;
     static constexpr bool kIsPagedKV   = Problem::kIsPagedKV;
     static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;

@@ -64,6 +66,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
     }();

+    static constexpr index_t kAlignmentOacc =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
+
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

@@ -72,22 +77,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }

@@ -138,6 +143,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
+               index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
                void* smem_ptr) const
     {
         static_assert(

@@ -206,16 +212,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         set_tile(m, -numeric<SMPLComputeDataType>::infinity());
         clear_tile(l);

         const auto q_origin = q_dram_window.get_window_origin();
-        const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
+        const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
             q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);

         // check early exit if no work to do
         if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
         {
-            const index_t original_num_total_loop =
-                integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
-            if(original_num_total_loop <= 0)
+            const index_t logical_num_total_loop =
+                integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
+            if(logical_num_total_loop <= 0)
             {
                 if constexpr(kStoreLSE)
                 {

@@ -234,40 +240,48 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             }
         }

-        // make sure the first tile is completely located in page-block
-        const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
-            if constexpr(kIsPagedKV)
-            {
-                return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
-            }
-            else
-            {
-                return seqlen_k_start_;
-            }
-        }();
+        const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
+        const index_t physical_seqlen_k_end   = logical_seqlen_k_end + kv_l2p_offset;
+        // make sure the first tile is completely located in page-block (page-block size should be
+        // divisible by kN0)
+        // relationship between each *_start variables: aligned_physical_seqlen_k_start <=
+        // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
+        const index_t aligned_physical_seqlen_k_start =
+            [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
+                if constexpr(kIsPagedKV)
+                {
+                    return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
+                }
+                else
+                {
+                    return physical_seqlen_k_start_;
+                }
+            }();

         const index_t num_total_loop =
-            integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
+            integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);

         auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
-            k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
+            k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-        auto bias_dram_window = make_tile_window(
-            bias_dram_block_window_tmp.get_bottom_tensor_view(),
-            bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
-            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto bias_dram_window =
+            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bias_dram_block_window_tmp.get_window_lengths(),
+                             {bias_origin.at(number<0>{}),
+                              logical_seqlen_k_start -
+                                  (physical_seqlen_k_start - aligned_physical_seqlen_k_start)}, // M/N
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
             v_dram_block_window_lengths,
-            {0, adjusted_seqlen_k_start}, // TODO: hdim split?
+            {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
             Policy::template MakeVDramTileDistribution<Problem>());

         auto q_tile = tile_elementwise_in(q_element_func, q);

         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);

@@ -374,7 +388,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                         constexpr auto i_j_idx = make_tuple(idx0, idx1);

                         s_acc(i_j_idx) *= scale_s;
-                        position_encoding.update(s_acc(i_j_idx), row, col);
+                        // position_encoding accept only logical coordinates, do conversion here
+                        position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
                     });
                 });
             }

@@ -392,29 +407,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             {
                 const auto k_origin = k_page_block_navigator.to_global_window_origin(
                     i_page_block_k, k_dram_block_window.get_window_origin());
-                set_tile_if(s_acc,
-                            -numeric<SMPLComputeDataType>::infinity(),
-                            [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
-                                auto tile_idx) {
-                                const auto col =
-                                    k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                                if constexpr(kIsPagedKV)
-                                {
-                                    return col < seqlen_k_start_ || seqlen_k_end_ <= col;
-                                }
-                                else
-                                {
-                                    return seqlen_k_end_ <= col;
-                                }
-                            });
+                set_tile_if(
+                    s_acc,
+                    -numeric<SMPLComputeDataType>::infinity(),
+                    [&,
+                     physical_seqlen_k_start_ = physical_seqlen_k_start,
+                     physical_seqlen_k_end_   = physical_seqlen_k_end](auto tile_idx) {
+                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                        if constexpr(kIsPagedKV)
+                        {
+                            return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col;
+                        }
+                        else
+                        {
+                            return physical_seqlen_k_end_ <= col;
+                        }
+                    });
             }

             if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
             {
                 const auto k_origin = k_page_block_navigator.to_global_window_origin(
                     i_page_block_k, k_dram_block_window.get_window_origin());
+                // mask accept only logical coordinates, do conversion here
                 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
-                                                           k_origin.at(number<0>{}),
+                                                           k_origin.at(number<0>{}) - kv_l2p_offset,
                                                            number<kM0>{},
                                                            number<kN0>{});

                 if(need_perpixel_check)

@@ -423,7 +440,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                         s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                             const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                             const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                            return mask.IsOutOfBound(row, col);
+                            return mask.IsOutOfBound(row, col - kv_l2p_offset);
                         });
                 }
             }

@@ -654,6 +671,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
+               index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
                void* smem_ptr) const
     {
         return operator()(q_dram_block_window_tmp,

@@ -676,6 +694,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                           mask,
                           position_encoding,
                           scale_s,
+                          kv_l2p_offset,
                           smem_ptr);
     }
 };
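Note: the paged-KV bookkeeping introduced above reduces to a few integer operations. Below is a minimal standalone sketch of that arithmetic, not ck_tile code: plain int stands in for ck_tile::index_t, and inline floor/ceil divisions stand in for integer_divide_floor/integer_divide_ceil. The concrete numbers are illustrative only.

// Standalone sketch of the logical->physical seqlen_k mapping above.
#include <cassert>

int main()
{
    const int kN0           = 64;  // seqlen_k tile size
    const int kv_l2p_offset = 100; // logical-to-physical offset of the seqlen_k coordinate

    // Logical tile range reported by the mask for this split (illustrative values).
    const int logical_seqlen_k_start = 32;
    const int logical_seqlen_k_end   = 224;

    // Shift into physical (page-table) coordinates.
    const int physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset; // 132
    const int physical_seqlen_k_end   = logical_seqlen_k_end + kv_l2p_offset;   // 324

    // Paged-KV path: align the window start down to a kN0 multiple so the first
    // tile sits entirely inside one page-block (page-block size divisible by kN0).
    const int aligned_physical_seqlen_k_start =
        kN0 * (physical_seqlen_k_start / kN0); // 64 * 2 = 128

    // Loop count covers the full physical range from the aligned start.
    const int num_total_loop =
        (physical_seqlen_k_end - aligned_physical_seqlen_k_start + kN0 - 1) / kN0; // ceil(196/64)

    assert(aligned_physical_seqlen_k_start <= physical_seqlen_k_start);
    assert(num_total_loop == 4);
    return 0;
}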
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp

@@ -9,11 +9,20 @@
 namespace ck_tile {

 // This pipeline is qkv all located in LDS
-using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
-                                        /* NumPrefetchV = */ 1>;
+struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
+                                          /* AsyncCopyK = */ false,
+                                          /* AsyncCopyV = */ false,
+                                          /* NumPrefetchK = */ 1,
+                                          /* NumPrefetchV = */ 1>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
+    {
+        using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
+        return static_cast<index_t>(16 / sizeof(OaccDataType));
+    }
+};

 } // namespace ck_tile
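Note: the new GetAlignmentOacc() computes how many OaccDataType elements fit into one 16-byte vector access. A standalone sketch of that computation (the 16-byte width is taken from the hunk above; the wrapper and example types are illustrative):

// Standalone sketch of the 16-byte vector alignment rule above.
#include <cstdio>

template <typename OaccDataType>
constexpr int get_alignment_oacc()
{
    // 16 bytes per vector access -> element count depends on the accumulator type.
    return static_cast<int>(16 / sizeof(OaccDataType));
}

int main()
{
    std::printf("float : %d elements\n", get_alignment_oacc<float>());  // 4
    std::printf("double: %d elements\n", get_alignment_oacc<double>()); // 2
    return 0;
}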
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp

@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
     using FmhaMask = remove_cvref_t<FmhaMask_>;
     using Traits   = remove_cvref_t<Traits_>;

-    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
-    static constexpr bool kIsGroupMode  = kIsGroupMode_;
+    static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
+    static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
+    static constexpr index_t kBlockSize     = BlockFmhaShape::NumWarps * get_warp_size();
+    static constexpr bool kIsGroupMode      = kIsGroupMode_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;

@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     using FmhaMask = remove_cvref_t<FmhaMask_>;
     using Traits   = remove_cvref_t<Traits_>;

-    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
-    static constexpr bool kIsGroupMode  = kIsGroupMode_;
+    static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
+    static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
+    static constexpr index_t kBlockSize     = BlockFmhaShape::NumWarps * get_warp_size();
+    static constexpr bool kIsGroupMode      = kIsGroupMode_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;

@@ -115,7 +121,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     using ODataType = remove_cvref_t<ODataType_>;
     using Traits    = remove_cvref_t<Traits_>;

-    static constexpr index_t kBlockSize = 256;
+    static constexpr index_t kNumWarps  = kM0_ / (get_warp_size() / 4);
+    static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
     static constexpr bool kIsGroupMode  = kIsGroupMode_;

     static constexpr index_t kHeadDimV = HeadDimV_;
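Note: the combine-pipeline change above derives the block size from the kM0_ tile instead of hardcoding 256. A standalone sketch of the arithmetic, assuming a 64-lane AMD wavefront (get_warp_size() == 64 is an assumption, not stated in the diff):

// Standalone sketch of the kNumWarps / kBlockSize derivation above.
#include <cassert>

constexpr int get_warp_size() { return 64; } // assumption: AMD wavefront of 64 lanes

constexpr int num_warps(int kM0) { return kM0 / (get_warp_size() / 4); }

int main()
{
    // kM0_ = 64 reproduces the previously hardcoded block size of 256 ...
    assert(num_warps(64) * get_warp_size() == 256);
    // ... while smaller kM0_ tiles now get proportionally smaller blocks.
    assert(num_warps(32) * get_warp_size() == 128);
    return 0;
}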
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp

@@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;

@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }

@@ -242,11 +243,11 @@ struct BlockFmhaPipelineQRKSVS
             {seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-        auto bias_dram_window = make_tile_window(
-            bias_dram_block_window_tmp.get_bottom_tensor_view(),
-            bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto bias_dram_window =
+            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bias_dram_block_window_tmp.get_window_lengths(),
+                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
             randval_dram_block_window_tmp, seqlen_k_start);

@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
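Note: the kK0BlockLength -> kQKHeaddim rename makes explicit that k0_loops counts kK0-wide steps across the Q/K head dimension of gemm_0. A standalone sketch with illustrative sizes (not taken from any particular tile configuration):

// Standalone sketch of the k0_loops arithmetic behind the rename above.
#include <cassert>

int main()
{
    const int kQKHeaddim = 128; // Q/K head dimension per block (illustrative)
    const int kK0        = 32;  // K-dim step of gemm_0 (illustrative)
    const int k0_loops   = kQKHeaddim / kK0; // 4 steps across the head dimension
    assert(2 <= k0_loops); // mirrors the static_assert in the pipeline
    return 0;
}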
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

@@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)

@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                 return 1;
             }

-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                              FmhaMask::IsMasking)

@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 2;
                 else
                     return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }

@@ -314,11 +315,11 @@ struct BlockFmhaPipelineQRKSVSAsync
         }();

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-        auto bias_dram_window = make_tile_window(
-            bias_dram_block_window_tmp.get_bottom_tensor_view(),
-            bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto bias_dram_window =
+            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bias_dram_block_window_tmp.get_window_lengths(),
+                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
             randval_dram_block_window_tmp, seqlen_k_start);

@@ -330,16 +331,17 @@ struct BlockFmhaPipelineQRKSVSAsync
             Policy::template MakeVDramTileDistribution<Problem>());

         // prefetch K tile
-        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
+        async_load_tile_raw(
+            k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
         move_tile_window(k_dram_window, {0, kK0});
         __builtin_amdgcn_sched_barrier(0);

-        buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
+        buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
         (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
         // auto q_tile = q; // tile_elementwise_in(q_element_func, q);

         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(1 <= k0_loops);

@@ -354,12 +356,13 @@ struct BlockFmhaPipelineQRKSVSAsync
         static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
             async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
                                 k_dram_window,
+                                number<-1>{},
                                 k_oob_ck,
                                 k_pre_np);
             if constexpr(i_k0 < k0_loops - 1)
                 move_tile_window(k_dram_window, {0, kK0});

-            async_load_fence(k_dram_window.get_num_access());
+            async_load_fence(k_dram_window.get_num_of_access());
             __builtin_amdgcn_s_barrier();
             __builtin_amdgcn_sched_barrier(0);
             gemm_0(s_acc,

@@ -385,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             __builtin_amdgcn_s_barrier();

             const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-            auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
+            auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
             __builtin_amdgcn_sched_barrier(0);
             { // tail
                 gemm_0(s_acc,

@@ -513,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                     move_tile_window(
                         v_dram_window,
                         {0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
-                    v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
+                    v_buf = load_tile(
+                        v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
                 }
                 __builtin_amdgcn_sched_barrier(0);

@@ -617,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                     if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
                     {
-                        v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
+                        v_buf = load_tile(
+                            v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
                     }
                     block_sync_lds();
                     gemm_1(o_acc,

@@ -664,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync
                 if constexpr(k1_loops >= 2 &&
                              LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                     __builtin_amdgcn_s_barrier();
-                async_load_tile_raw(
-                    k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
+                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
+                                    k_dram_window,
+                                    number<-1>{},
+                                    k_oob_ck,
+                                    k_pre_np);
                 move_tile_window(k_dram_window, {0, kK0});
             }
             // tail
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp

@@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;

@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }

@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp

@@ -9,9 +9,10 @@
 namespace ck_tile {

+/// NOTICE: we no-longer use this pipeline.
 // This pipeline is qkv all located in LDS
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
-struct BlockFmhaPipelineQSKSVS
+struct [[deprecated]] BlockFmhaPipelineQSKSVS
 {
     using Problem = remove_cvref_t<Problem_>;
     using Policy  = remove_cvref_t<Policy_>;

@@ -35,12 +36,13 @@ struct BlockFmhaPipelineQSKSVS
     static constexpr index_t kBlockSize = Problem::kBlockSize;

     static constexpr index_t kM0 = BlockFmhaShape::kM0;
     static constexpr index_t kN0 = BlockFmhaShape::kN0;
     static constexpr index_t kK0 = BlockFmhaShape::kK0;
     static constexpr index_t kN1 = BlockFmhaShape::kN1;
     static constexpr index_t kK1 = BlockFmhaShape::kK1;
-    static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     static constexpr bool kPadSeqLenQ  = Problem::kPadSeqLenQ;

@@ -55,22 +57,22 @@ struct BlockFmhaPipelineQSKSVS
             return Problem::kBlockPerCu;
         else
         {
-            if constexpr(kK0BlockLength <= 32)
+            if constexpr(kQKHeaddim <= 32)
             {
                 return 2;
             }
-            else if constexpr(kK0BlockLength <= 64)
+            else if constexpr(kQKHeaddim <= 64)
             {
                 return 3;
             }
-            else if constexpr(kK0BlockLength <= 128)
+            else if constexpr(kQKHeaddim <= 128)
             {
                 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     return 1;
                 else
                     return 2;
             }
-            else if constexpr(kK0BlockLength <= 256)
+            else if constexpr(kQKHeaddim <= 256)
             {
                 return 1;
             }

@@ -234,7 +236,7 @@ struct BlockFmhaPipelineQSKSVS
         // prefetch K tile
         index_t i_total_loops = 0;
-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         static_assert(2 <= k0_loops);
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp

@@ -5,9 +5,9 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
-#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"

@@ -15,6 +15,7 @@
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"

 // TODO: remove this
 #define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0

@@ -54,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr index_t MWarp = config.template at<1>();

         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;

         constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
         constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;

@@ -64,51 +65,72 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
         constexpr index_t M1 = MWarp;
         constexpr index_t M0 = kMPerBlock / (M2 * M1);

-        return make_static_tile_distribution(
-            tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
-                                       tuple<sequence<1>, sequence<2, 1>>,
-                                       tuple<sequence<1>, sequence<1, 2>>,
-                                       sequence<1, 2, 2>,
-                                       sequence<0, 0, 2>>{});
+        if constexpr(1 < Problem::kNumGemm0Warps)
+        {
+            return make_static_tile_distribution(
+                tile_distribution_encoding<sequence<1>,
+                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+                                           tuple<sequence<1>, sequence<2, 1>>,
+                                           tuple<sequence<1>, sequence<1, 2>>,
+                                           sequence<1, 2, 2>,
+                                           sequence<0, 0, 2>>{});
+        }
+        else
+        {
+            static_assert(MWarp == 1);
+            return make_static_tile_distribution(
+                tile_distribution_encoding<sequence<1>,
+                                           tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
+                                           tuple<sequence<2, 1>>,
+                                           tuple<sequence<1, 2>>,
+                                           sequence<1, 2, 2>,
+                                           sequence<0, 0, 2>>{});
+        }
     }

     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
     {
         using GemmProblem =
-            GemmPipelineProblem<typename Problem::QDataType,
-                                typename Problem::KDataType,
-                                typename Problem::SaccDataType,
-                                TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                                       Problem::BlockFmhaShape::kN0,
-                                                       Problem::BlockFmhaShape::kK0>,
-                                              typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                                              typename Problem::BlockFmhaShape::Gemm0WarpTile>,
-                                TileGemmTraits<Problem::kPadSeqLenQ,
-                                               Problem::kPadSeqLenK,
-                                               Problem::kPadHeadDimQ,
-                                               typename tensor_layout::gemm::RowMajor,
-                                               typename tensor_layout::gemm::ColumnMajor,
-                                               typename tensor_layout::gemm::RowMajor>>;
+            BlockGemmProblem<typename Problem::QDataType,
+                             typename Problem::KDataType,
+                             typename Problem::SaccDataType,
+                             Problem::kNumGemm0Warps * get_warp_size(),
+                             TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                    Problem::BlockFmhaShape::kN0,
+                                                    Problem::BlockFmhaShape::kK0>,
+                                           typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                           typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

         constexpr auto warp_gemm = []() {
+            constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+            static_assert(WarpGemmM == 16 || WarpGemmM == 32);
+
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else // WarpGemmM == 16
+                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
                               std::is_same_v<typename Problem::SaccDataType, float>)
             {
-                return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                if constexpr(WarpGemmM == 32)
+                    return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else // WarpGemmM == 16
+                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
                               std::is_same_v<typename Problem::SaccDataType, float>)
             {
+                static_assert(WarpGemmM == 32);
                 // TODO: hard coded here. Otherwise, it may incorrect result
                 constexpr index_t swizzle_factor = 4;
                 return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<

@@ -123,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
             typename Problem::BlockFmhaShape::Gemm0BlockWarps,
             decltype(warp_gemm)>;

-        return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
+        if constexpr(1 < Problem::kNumGemm0Warps)
+            return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
+        else
+            return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
     }
 };

+/// NOTICE: we no-longer use this policy.
 template <>
-struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
+struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
 {
     static constexpr bool QLoadOnce = false;

@@ -207,20 +233,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
     {
         using GemmProblem =
-            GemmPipelineProblem<typename Problem::QDataType,
-                                typename Problem::KDataType,
-                                typename Problem::SaccDataType,
-                                TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                                       Problem::BlockFmhaShape::kN0,
-                                                       Problem::BlockFmhaShape::kK0>,
-                                              typename Problem::BlockFmhaShape::Gemm0BlockWarps,
-                                              typename Problem::BlockFmhaShape::Gemm0WarpTile>,
-                                TileGemmTraits<Problem::kPadSeqLenQ,
-                                               Problem::kPadSeqLenK,
-                                               Problem::kPadHeadDimQ,
-                                               typename tensor_layout::gemm::RowMajor,
-                                               typename tensor_layout::gemm::ColumnMajor,
-                                               typename tensor_layout::gemm::RowMajor>>;
+            BlockGemmProblem<typename Problem::QDataType,
+                             typename Problem::KDataType,
+                             typename Problem::SaccDataType,
+                             Problem::kBlockSize,
+                             TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                    Problem::BlockFmhaShape::kN0,
+                                                    Problem::BlockFmhaShape::kK0>,
+                                           typename Problem::BlockFmhaShape::Gemm0BlockWarps,
+                                           typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

         constexpr auto warp_gemm = []() {
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&

@@ -302,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     template <> struct
     LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };

+    template <> struct
+    LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
+
     template <> struct
     LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
     // clang-format on

@@ -311,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     {
         using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
         constexpr index_t kN0 = BlockFmhaShape::kN0;
         constexpr index_t kK0 = BlockFmhaShape::kK0;
         constexpr index_t kK1 = BlockFmhaShape::kK1;
-        constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
+        constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;

-        constexpr index_t k0_loops = kK0BlockLength / kK0;
+        constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;

         return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};

@@ -363,12 +387,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
             constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
             constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
             constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
-            constexpr index_t kMaxVecLoad =
-                min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
-            constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);

-            constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
-                                             ? kMaxVecLoad
-                                             : (total_pixels / kMinVecLoad);
-
-            return kVecLoad;
+            // TODO: not correct!
+            if constexpr(total_pixels > 4)
+                return 4;
+            else
+                return 2;
         }
         else
         {

@@ -382,10 +409,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
         constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        using CWarpDstr = typename WG::CWarpDstr;
-        constexpr auto vec =
-            CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
-        return vec;
+        return WG::WarpGemmAttribute::Impl::kCM1PerLane;
     }

     template <typename Problem>

@@ -394,10 +419,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
         constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
         using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        using CWarpDstr = typename WG::CWarpDstr;
-        constexpr auto vec =
-            CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
-        return vec;
+        return WG::WarpGemmAttribute::Impl::kCM1PerLane;
     }

     template <typename Problem>

@@ -448,44 +471,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         return max(SingleKSize, SingleVSize);
     }

-    template <typename Problem, typename BlockGemm>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor()
-    {
-        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
-
-        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp = config.template at<1>();
-        constexpr index_t NWarp = config.template at<2>();
-
-        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
-        constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
-
-        constexpr auto q_block_outer_dstr_encoding =
-            tile_distribution_encoding<sequence<NWarp>,
-                                       tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
-                                       tuple<sequence<1, 0>>,
-                                       tuple<sequence<1, 0>>,
-                                       sequence<1, 2>,
-                                       sequence<0, 0>>{};
-
-        constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
-
-        constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
-
-        return q_block_dstr;
-    }
-
     // TODO: this is used for non async copy desc. unify in the future
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
     {
         constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
-        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
         constexpr index_t kKPack     = GetSmemKPackK<Problem>();

         constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(

@@ -872,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         constexpr index_t K0 = kKPerBlock / K1;
         constexpr index_t N2 = get_warp_size() / K0;
         constexpr index_t N1 = kBlockSize / get_warp_size();
+        static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
+        static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
         constexpr index_t N0 = kNPerBlock / (N2 * N1);
         static_assert(N0 != 0);

@@ -885,36 +878,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         }
     }

-    template <typename Problem, typename BlockGemm>
+    template <typename BlockGemm>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
     {
-        constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0;
-        constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0;
-
-        constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
-        constexpr index_t MWarp = config.template at<1>();
-        constexpr index_t NWarp = config.template at<2>();
-
-        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-
-        // Construct C-Block-HostTensor
-        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
-            sequence<>,
-            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
-            tuple<sequence<1, 2>>,
-            tuple<sequence<1, 1>>,
-            sequence<1, 2>,
-            sequence<0, 0>>{};
-
-        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
-
-        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
-
-        return c_block_dstr;
+        return BlockGemm::MakeCBlockTile().get_tile_distribution();
     }

     template <typename Problem>

@@ -968,20 +935,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
     {
         using GemmProblem =
-            GemmPipelineProblem<typename Problem::PDataType,
-                                typename Problem::VDataType,
-                                typename Problem::OaccDataType,
-                                TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
-                                                       Problem::BlockFmhaShape::kN1,
+            BlockGemmProblem<typename Problem::PDataType,
+                             typename Problem::VDataType,
+                             typename Problem::OaccDataType,
+                             Problem::kNumGemm1Warps * get_warp_size(),
+                             TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
+                                                    Problem::BlockFmhaShape
...
::
kK1
>
,
Problem
::
BlockFmhaShape
::
kN1
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
auto
warp_gemm
=
[
&
]()
{
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
1f9546e0
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -7,6 +7,20 @@
 namespace ck_tile {

+static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
+{
+    if(len == 96)
+        return 128;
+    if(len == 160)
+        return 256;
+    // only length of 96, 160 and power-of-two is supported
+    if(!(len & (len - 1)))
+        return len;
+    return 0;
+};
+
 template <typename BlockTile_, // sequence<...
           typename Gemm0BlockWarps_,
           typename Gemm0WarpTile_,
...
@@ -21,20 +35,27 @@ struct TileFmhaShape
     using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
     using Gemm1WarpTile   = remove_cvref_t<Gemm1WarpTile_>;

-    static constexpr index_t NumWarps =
-        reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
-    static_assert(NumWarps ==
-                  reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
+    static constexpr index_t NumGemm0Warps =
+        reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
+    static constexpr index_t NumGemm1Warps =
+        reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
+    static_assert(NumGemm1Warps % NumGemm0Warps == 0);
+    static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
+    static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);

     static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
     static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
     static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
     static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
     static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
-    static constexpr index_t kK0BlockLength =
+    static constexpr index_t kQKHeaddim =
         BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
                                     // once (or repeately load Q as a whole tile)
-    static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0");
+    static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
+    static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);

     // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
     static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
...
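One note on the new `ceil_to_qualified_tile_length` helper: since it is constexpr, its mapping can be checked at compile time. A minimal sanity sketch (assuming this header is included; the values follow directly from the branches above):

    static_assert(ck_tile::ceil_to_qualified_tile_length(96) == 128);
    static_assert(ck_tile::ceil_to_qualified_tile_length(160) == 256);
    static_assert(ck_tile::ceil_to_qualified_tile_length(64) == 64); // power of two: kept as-is
    static_assert(ck_tile::ceil_to_qualified_tile_length(48) == 0);  // unsupported length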
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
1f9546e0
...
@@ -39,7 +39,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kPadHeadDimV_ /* paddding for hdim_v */,
           BlockAttentionBiasEnum BiasEnum_,
           bool kHasBiasGrad_,
-          bool kStoreLSE_,
+          bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
           bool kDoFp8StaticQuant_,
           bool kIsPagedKV_,
           bool kHasUnevenSplits_,
...
include/ck_tile/ops/fused_moe.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
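For orientation, the 24/8-bit packing described above can be written out as standalone helpers; this is an illustrative sketch (the kernels themselves use the MOE_SORTING_MOCK_ID macro defined in moe_sorting_kernel.hpp later in this commit), and the padded-length arithmetic follows the formula quoted in the comment:

    #include <cstdint>

    // low 24 bits: token id, high 8 bits: topk id (mirrors MOE_SORTING_MOCK_ID)
    inline uint32_t pack_sorted_id(uint32_t token_id, uint32_t topk_id)
    {
        return (token_id & 0x00ffffffu) | ((topk_id & 0xffu) << 24);
    }
    inline uint32_t unpack_token_id(uint32_t v) { return v & 0x00ffffffu; }
    inline uint32_t unpack_topk_id(uint32_t v) { return v >> 24; }

    // e.g. num_experts = 6, topk = 3, M_a = 4, input_tokens = 5 (the example above):
    // max_num_tokens_padded = 3 * 5 + 6 * (4 - 1) = 33 (an upper bound; the actual
    // padded count in the example is 28, known only after the GPU-side sort)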
namespace ck_tile {

// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
//   nr : n / Block_Nr
//   kr : k / Block_Kr
//   w  : fattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
    const void* a_ptr;              // [m, k], input token
    const void* a_scale_ptr;        // [m, 1], token scale
    const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
    const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
    const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
    const void* d_scale_ptr;        // [e, 1, k], down scale
    const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
    void* o_ptr;                    // [m, k], output token

    const void* sorted_token_ids_ptr;  // [max_num_tokens_padded]
    const void* sorted_weight_ptr;     // [max_num_tokens_padded]
    const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
    const void* num_sorted_tiles_ptr;  // [1]

    index_t hidden_size;       // k
    index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
    index_t num_tokens;        // input number of tokens for current iteration
    index_t num_experts;       // number of groups
    index_t topk;              // need this?

    index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
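A minimal host-side sketch of populating these arguments (every numeric value and every `_dev` pointer below is an illustrative placeholder for a previously allocated device buffer, not a requirement of the API; the scale pointers are left null by the value-initialization):

    FusedMoeGemmHostArgs hargs{};
    hargs.a_ptr                 = a_dev;           // [num_tokens, hidden_size] input tokens
    hargs.g_ptr                 = g_dev;           // pre-shuffled gate(+up) weights
    hargs.d_ptr                 = d_dev;           // pre-shuffled down weights
    hargs.o_ptr                 = o_dev;           // [num_tokens, hidden_size] output
    hargs.sorted_token_ids_ptr  = sorted_ids_dev;  // produced by the moe-sorting kernel
    hargs.sorted_weight_ptr     = sorted_w_dev;
    hargs.sorted_expert_ids_ptr = sorted_eids_dev;
    hargs.num_sorted_tiles_ptr  = num_tiles_dev;
    hargs.hidden_size           = 4096;            // k
    hargs.intermediate_size     = 14336;           // n / TP
    hargs.num_tokens            = 128;
    hargs.num_experts           = 8;
    hargs.topk                  = 2;
    hargs.stride_token          = hargs.hidden_size;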
// This is scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
    using Partitioner = remove_cvref_t<Partitioner_>;
    using Pipeline    = remove_cvref_t<Pipeline_>;
    using Epilogue    = remove_cvref_t<Epilogue_>;

    // TODO: not used
    // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
    // static_assert(kBlockPerCu > 0);
    using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
    static constexpr index_t BlockSize_ = BlockShape::BlockSize;

    using ADataType            = typename Pipeline::Problem::ADataType;
    using GDataType            = typename Pipeline::Problem::GDataType;
    using DDataType            = typename Pipeline::Problem::DDataType;
    using AccDataType          = typename Pipeline::Problem::AccDataType;
    using ODataType            = typename Pipeline::Problem::ODataType;
    using AScaleDataType       = typename Pipeline::Problem::AScaleDataType;
    using GScaleDataType       = typename Pipeline::Problem::GScaleDataType;
    using DScaleDataType       = typename Pipeline::Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Pipeline::Problem::TopkWeightDataType;
    using IndexDataType        = typename Pipeline::Problem::IndexDataType;
    using YDataType            = typename Pipeline::Problem::YDataType;

    using Traits = typename Pipeline::Problem::Traits;

    static constexpr bool UseUK = true;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float>  { static constexpr const char * name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<fp8_t>  { static constexpr const char * name = "fp8"; };
    template <> struct t2s<bf8_t>  { static constexpr const char * name = "bf8"; };
    template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
    // clang-format on

    CK_TILE_HOST static std::string GetName()
    {
#define _SS_ std::string
#define _TS_ std::to_string
        // clang-format off
        using S_ = BlockShape;
        auto prec_str = [&] () {
            std::string base_str = _SS_(t2s<ADataType>::name);
            if (!std::is_same_v<ADataType, GDataType>) {
                base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
            }
            return base_str;
        }();

        return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
            _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
            _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
            _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
        // clang-format on
    }
    struct FusedMoeGemmKargs
    {
        const void* a_ptr;              // [m, k], input token
        const void* a_scale_ptr;        // [m, 1], token scale
        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
        const void* d_scale_ptr;        // [e, 1, k], down scale
        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
        void* o_ptr;                    // [m, k], output token

        const void* sorted_token_ids_ptr;
        const void* sorted_weight_ptr;
        const void* sorted_expert_ids_ptr;
        const void* num_sorted_tiles_ptr;

        index_t hidden_size;       // k
        index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
        index_t num_tokens;        // input number of tokens for current iteration
        index_t num_experts;       // number of groups
        index_t topk;              // need this?

        index_t stride_token; // for input/output, stride for each row, should >= hidden_size
    };

    // TODO: switch karg based on
    using Kargs = FusedMoeGemmKargs;
    using Hargs = FusedMoeGemmHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        // TODO: hargs/kargs not guranteed to be the same
        return bit_cast<Kargs>(hargs);
    }
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        constexpr index_t block_m = BlockShape::Block_M0;
        int max_num_tokens_padded =
            hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
        // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
        return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
    }
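As a quick check of this bound, here is a worked instance reusing the example sizes from the indexing comment at the top of this file (illustrative only):

    // num_tokens = 5, topk = 3, num_experts = 6, Block_M0 = 4:
    //   max_num_tokens_padded = 3 * 5 + 6 * 4 - 3 = 36
    // The linear partitioner then launches ceil(36 / 4) = 9 tiles along M
    // (times ceil(intermediate_size / Block_N0) tiles along N); tiles beyond
    // the real num_sorted_tiles simply exit early in operator() below.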
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if constexpr(UseUK)
        {
            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
            num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;

            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            // if(threadIdx.x == 0)
            //     printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
            //     intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
            //     static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
            //     num_sorted_tiles? 1 : 0, intermediate_tile_id);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
        }
        else
        {
            // allocate LDS
            // __shared__ char smem_ptr[GetSmemSize()];
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

            constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
            index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
            index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
            index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1;       // should be same as kr_0
            index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
            index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
            index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];

            // note this is in unit of tile, need multiple tile size to get the index
            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
                reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

            // index along intermediate_size
            // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
            // BlockShape::Block_N0);
            index_t interm_idx_nr =
                __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);

            const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
            const auto sorted_token_id =
                a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;

            index_t token_id =
                reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
            auto topk_weight =
                reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];

            const auto a_window = [&]() {
                // A is already pre-padded in previous kernel
                const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
                const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<Pipeline::kAlignmentA>{},
                    number<1>{});
                // gather is here use indexing transform
                const auto a_gather_view_ = transform_tensor_view(
                    a_view_,
                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                               make_pass_through_transform(kargs.hidden_size)),
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                const auto a_window_ = make_tile_window(
                    a_gather_view_,
                    make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
                    {0, 0});
                return a_window_;
            }();

            // TODO: gtile using NSub to have less register pressure
            const auto g_window = [&]() {
                const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                         interm_idx_nr * kr_0 * BlockShape::Block_W0;
                const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                    g_ptr,
                    make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                    number<Pipeline::kAlignmentG>{},
                    number<1>{});
                const auto g_view_1_ = pad_tensor_view(
                    g_view_,
                    make_tuple(number<BlockShape::Block_Nr0>{},
                               number<BlockShape::Block_Kr0>{},
                               number<BlockShape::Block_W0>{}),
                    sequence<PadIntermediateSize, PadHiddenSize, 0>{});
                const auto g_window_ = make_tile_window(
                    g_view_1_,
                    make_tuple(number<BlockShape::Block_Nr0>{},
                               number<BlockShape::Block_Kr0>{},
                               number<BlockShape::Block_W0>{}),
                    {0, 0, 0});
                return g_window_;
            }();

            const auto d_window = [&]() {
                const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                         interm_idx_nr * BlockShape::Block_W1;
                // note interm_idx_nr is along the gemm-k dim of 2nd gemm
                const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                    d_ptr,
                    make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                    make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                    number<Pipeline::kAlignmentD>{},
                    number<1>{});
                const auto d_view_1_ = pad_tensor_view(
                    d_view_,
                    make_tuple(number<BlockShape::Block_Nr1>{},
                               number<BlockShape::Block_Kr1>{},
                               number<BlockShape::Block_W1>{}),
                    sequence<PadHiddenSize, PadIntermediateSize, 0>{});
                const auto d_window_ = make_tile_window(
                    d_view_1_,
                    make_tuple(number<BlockShape::Block_Nr1>{},
                               number<BlockShape::Block_Kr1>{},
                               number<BlockShape::Block_W1>{}),
                    {0, 0, 0});
                return d_window_;
            }();

            auto o_window = [&]() {
                ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
                auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                                                      memory_operation_enum::atomic_add>(
                    o_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<Pipeline::kAlignmentO>{},
                    number<1>{});
                // gather is here
                auto o_scatter_view_ = transform_tensor_view(
                    o_view_,
                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                               make_pass_through_transform(kargs.hidden_size)),
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                auto o_window_ = make_tile_window(
                    o_scatter_view_,
                    make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
                    {0, 0});
                return o_window_;
            }();

            // do compute yeah
            Pipeline{}(a_window,
                       g_window,
                       d_window,
                       o_window,
                       topk_weight,
                       smem,
                       kargs.hidden_size,
                       kargs.intermediate_size,
                       kargs.stride_token);
        }
    }
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N1
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K1 | | |
N0 N0 x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K0 | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M0 | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
----------> | | |
contiguous | V V
| x-----x +----------+
+------------> M1 | Y | ---------> | Out(O) |
ACT x-----x +----------+
K1 = N0 dim
* Note: Act could be Gelu/Silu/...
* Note: some model does not have Up
*/
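As a plain-C++ reference for the dataflow sketched in the diagram, the per-token computation is roughly the following (illustrative only: it assumes Silu activation and the Gate+Up variant, ignores quantization scales, and the kernel itself tiles, pre-shuffles and vectorizes all of this):

    #include <cmath>
    #include <vector>

    std::vector<float> moe_expert_ref(const std::vector<float>& a, // [K0]      one routed token
                                      const std::vector<float>& G, // [N0 * K0] gate weights
                                      const std::vector<float>& U, // [N0 * K0] up weights
                                      const std::vector<float>& D, // [K0 * N0] down weights
                                      float topk_weight, int K0, int N0)
    {
        std::vector<float> y(N0), o(K0, 0.0f);
        for(int n = 0; n < N0; ++n)
        {
            float g = 0, u = 0;
            for(int k = 0; k < K0; ++k)
            {
                g += a[k] * G[n * K0 + k];
                u += a[k] * U[n * K0 + k];
            }
            y[n] = (g / (1.0f + std::exp(-g))) * u; // silu(gate) * up
        }
        // second gemm: K1 == N0 per the diagram; real kernel atomic-adds into o
        for(int k1 = 0; k1 < K0; ++k1)
            for(int n = 0; n < N0; ++n)
                o[k1] += topk_weight * y[n] * D[k1 * N0 + n];
        return o;
    }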
template <typename BlockTile_0_,
          typename WarpPerBlock_0_,
          typename WarpTile_0_,
          typename BlockTile_1_,
          typename WarpPerBlock_1_,
          typename WarpTile_1_>
struct FusedMoeGemmShape
{
    using BlockTile_0    = remove_cvref_t<BlockTile_0_>;
    using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
    using WarpTile_0     = remove_cvref_t<WarpTile_0_>;
    using BlockTile_1    = remove_cvref_t<BlockTile_1_>;
    using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
    using WarpTile_1     = remove_cvref_t<WarpTile_1_>;

    static constexpr index_t NumWarps =
        reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
    // TODO: we don't support half warps aound to 1 warp here
    static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));

    static constexpr index_t Block_M0        = BlockTile_0::at(number<0>{});
    static constexpr index_t Block_N0        = BlockTile_0::at(number<1>{});
    static constexpr index_t Block_K0        = BlockTile_0::at(number<2>{});
    static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
    static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
    static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
    static constexpr index_t Warp_M0         = WarpTile_0::at(number<0>{});
    static constexpr index_t Warp_N0         = WarpTile_0::at(number<1>{});
    static constexpr index_t Warp_K0         = WarpTile_0::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
    static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
    static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
    static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
    static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
    static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
    static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
    static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
    static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;

    static constexpr index_t Block_M1        = BlockTile_1::at(number<0>{});
    static constexpr index_t Block_N1        = BlockTile_1::at(number<1>{});
    static constexpr index_t Block_K1        = BlockTile_1::at(number<2>{});
    static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
    static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
    static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
    static constexpr index_t Warp_M1         = WarpTile_1::at(number<0>{});
    static constexpr index_t Warp_N1         = WarpTile_1::at(number<1>{});
    static constexpr index_t Warp_K1         = WarpTile_1::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
    static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
    static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
    static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
    static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
    static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
    static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
    static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
    static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;

    static constexpr index_t BlockSize = warpSize * NumWarps;

    // some assert
    static_assert(Block_M0 == Block_M1);
    static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up

    // pre-shuffle tile size compute (assume only for B matrix)
    // we flatten the each wave tile to a 1d linear tensor(at model loading time)
    // e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
    // we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
    // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
    static constexpr index_t Block_W0  = Warp_N0 * Warp_K0;
    static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
    static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
    static constexpr index_t Block_W1  = Warp_N1 * Warp_K1;
    static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
    static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
    static_assert(Block_W0 == Block_W1);
    // static_assert(Block_Nr0 == Block_Kr1);
};

} // namespace ck_tile
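To make the pre-shuffle bookkeeping concrete, here are the derived constants for one illustrative configuration (a 16x16x32 warp tile with Block_N0 = 512 and Block_K0 = 128; chosen for arithmetic clarity, not a statement about shipped tunings):

    // Warp_N0 = 16, Warp_K0 = 32   ->  Block_W0  = 16 * 32  = 512
    // Block_N0 = 512               ->  Block_Nr0 = 512 / 16 = 32
    // Block_K0 = 128               ->  Block_Kr0 = 128 / 32 = 4
    // i.e. each [Block_N0, Block_K0] weight tile is laid out as
    // [Block_Nr0, Block_Kr0, Block_W0] = [32, 4, 512], with every
    // [Warp_N0, Warp_K0] wave tile flattened into one contiguous
    // 512-element chunk at model-loading time.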
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {

template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
    // FusedMoeGemmShape
    using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;

    static constexpr const char* name = "lin";

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
                                   ck_tile::index_t /*intermediate_size*/)
    {
        index_t i_n = blockIdx.x;
        index_t i_m = blockIdx.y;
        return ck_tile::make_tuple(i_m, i_n);
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
    {
        // TODO: this may need tuning
        index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
        index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
        return dim3(ns, ms, 1);
    }
};
} // namespace ck_tile
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
struct MoeSortingHostArgs
{
    const void* p_topk_ids; // [token, topk]
    const void* p_weights;  // [token, topk]
    void* p_sorted_token_ids;
    void* p_sorted_weights;
    void* p_sorted_expert_ids;
    void* p_total_tokens_post_pad;
    // we fused the setzero of output of fused-moe buffer
    // set this pointer to nullptr will skip this operation
    void* p_moe_buf;
    index_t tokens;
    index_t unit_size; // this is the M_a of fused-moe kernel
    index_t num_experts;
    index_t topk;
    index_t moe_buf_bytes; // byte size of p_moe_buf
};
template <typename Problem_>
struct MoeSortingKernel
{
    using Problem = remove_cvref_t<Problem_>;

    using IndexType  = typename Problem::IndexType;
    using WeightType = typename Problem::WeightType;

    typedef MoeSortingHostArgs MoeSortingKargs;

    using Hargs = MoeSortingHostArgs;

    struct Kargs
    {
        const void* p_topk_ids;
        const void* p_weights;
        void* p_sorted_token_ids;
        void* p_sorted_weights;
        void* p_sorted_expert_ids;
        void* p_total_tokens_post_pad;
        void* p_moe_buf;
        index_t tokens;
        index_t num_experts;
        index_t moe_buf_bytes;
        index_t tokens_per_thread;
        mdiv unit_size_mdiv;
        mdiv topk_mdiv;
    };

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        // TODO: assume num-experts not too much
        return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
    }

    CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
    {
        return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
    }

    // in byte
    CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
    {
        const auto blocks = BlockSize(h);
        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_topk_ids              = h.p_topk_ids;
        k.p_weights               = h.p_weights;
        k.p_sorted_token_ids      = h.p_sorted_token_ids;
        k.p_sorted_weights        = h.p_sorted_weights;
        k.p_sorted_expert_ids     = h.p_sorted_expert_ids;
        k.p_moe_buf               = h.p_moe_buf;
        k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
        k.tokens                  = h.tokens;
        k.num_experts             = h.num_experts;
        k.moe_buf_bytes           = h.moe_buf_bytes;
        const auto blocks         = BlockSize(h);
        k.tokens_per_thread       = integer_divide_ceil(h.tokens * h.topk, blocks.x);
        k.unit_size_mdiv          = mdiv{static_cast<uint32_t>(h.unit_size)};
        k.topk_mdiv               = mdiv{static_cast<uint32_t>(h.topk)};
        return k;
    }
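For a feel of the launch-helper arithmetic above, assuming a wave64 target and the running example's num_experts = 6 (illustrative numbers only):

    // BlockSize  -> integer_least_multiple(6, 64) = 64 threads
    // GetSmemSize -> ((64 + 1) * 6 + (6 + 1)) * sizeof(index_t)
    //             =  397 * 4 = 1588 bytes, i.e. the (blockDim.x + 1) x
    //             num_experts counts table plus the (num_experts + 1) cumsum.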
    CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
    {
        return row * total_col + col;
    }

    CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
    {
        const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
        if(offset < buf_bytes / 16)
        {
            buf[offset] = uint8x16_t{0};
        }
    }

    CK_TILE_DEVICE void
    moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
                                const WeightType* __restrict__ weights,
                                index_t* p_sorted_token_ids,
                                WeightType* p_sorted_weights,
                                index_t* p_sorted_expert_ids,
                                index_t* p_total_tokens_post_pad,
                                const index_t num_experts,
                                const index_t tokens_per_thread,
                                const index_t numel,
                                const mdiv unit_size_mdiv,
                                const mdiv topk_mdiv,
                                void* smem) const
    {
        const index_t tid       = static_cast<index_t>(threadIdx.x);
        const index_t start_idx = tid * tokens_per_thread;

        index_t* shared_mem = reinterpret_cast<index_t*>(smem);

        index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
        index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)

        for(int i = 0; i < num_experts; ++i)
        {
            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
        }
        __syncthreads();

        if(tid < num_experts)
        {
            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
            {
                tokens_cnts[calc_index(num_experts, i, tid)] +=
                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
            }
        }
        // __syncthreads();

        if(tid == 0)
        {
            cumsum[0] = 0;
            for(int i = 1; i <= num_experts; ++i)
            {
                auto current_units = [&]() {
                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
                                 unit_size_mdiv.divisor - 1;
                    index_t y_ = unit_size_mdiv.div(x_);
                    return max(y_, 1) * unit_size_mdiv.divisor;
                }();
                cumsum[i] = cumsum[i - 1] + current_units;
            }
            *p_total_tokens_post_pad = cumsum[num_experts];
        }
        __syncthreads();

        if(tid < num_experts)
        {
            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
            {
                p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
            }
        }

#pragma unroll Problem_::InternalLoadUnroll
        for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
        {
            index_t expert_id = topk_id[i];
            index_t rank_post_pad =
                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
            uint32_t curr_token_id, curr_topk_id;
            topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
            p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
#else
            p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
#endif
            p_sorted_weights[rank_post_pad] = weights[i];
            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
        }

        const index_t prefill_token = topk_mdiv.div(numel);
        if(tid < num_experts)
        {
            index_t expert_offset =
                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
            while(expert_offset < cumsum[tid + 1])
            {
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
                p_sorted_token_ids[expert_offset] =
                    MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
                p_sorted_token_ids[expert_offset] = prefill_token;
#endif
                p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
                expert_offset++;
            }
        }
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if(blockIdx.x > 0)
        {
            if(kargs.p_moe_buf)
            {
                moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
                                        kargs.moe_buf_bytes);
            }
            return;
        }
        const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
        extern __shared__ char smem[];
        return moe_align_block_size_kernel(
            static_cast<const IndexType*>(kargs.p_topk_ids),
            static_cast<const WeightType*>(kargs.p_weights),
            static_cast<IndexType*>(kargs.p_sorted_token_ids),
            static_cast<WeightType*>(kargs.p_sorted_weights),
            static_cast<IndexType*>(kargs.p_sorted_expert_ids),
            static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
            kargs.num_experts,
            kargs.tokens_per_thread,
            numel,
            kargs.unit_size_mdiv,
            kargs.topk_mdiv,
            smem);
    }
};

#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile
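A small CPU sketch of the counting and padding passes above, applied to the worked example from the comment block (illustrative; it reproduces cumsum and sorted_expert_ids, but skips the token/weight scatter):

    #include <algorithm>
    #include <vector>

    void moe_sorting_reference_example()
    {
        // the worked example from the comment block: tokens = 5, topk = 3
        std::vector<int> topk_ids = {0, 3, 5,  2, 3, 5,  1, 3, 5,  1, 2, 3,  1, 3, 5};
        int num_experts = 6, unit_size = 4; // unit_size == M_a

        std::vector<int> cnt(num_experts, 0);
        for(int e : topk_ids)
            ++cnt[e]; // -> {1, 3, 2, 5, 0, 4}

        std::vector<int> cumsum(num_experts + 1, 0);
        for(int e = 0; e < num_experts; ++e)
        {
            // every expert occupies at least one unit, rounded up to unit_size
            int units     = std::max((cnt[e] + unit_size - 1) / unit_size, 1);
            cumsum[e + 1] = cumsum[e] + units * unit_size;
        }
        // cumsum -> {0, 4, 8, 12, 20, 24, 28}; num_tokens_post_padded -> 28

        std::vector<int> sorted_expert_ids;
        for(int e = 0; e < num_experts; ++e)
            for(int i = cumsum[e]; i < cumsum[e + 1]; i += unit_size)
                sorted_expert_ids.push_back(e);
        // sorted_expert_ids -> {0, 1, 2, 3, 3, 4, 5}, as in the comment block
    }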
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmEx
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Problem::ADataType;
    using GDataType            = typename Problem::GDataType;
    using DDataType            = typename Problem::DDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "fused_moe_flatmm";

    // TODO: there are multiple buffers
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
    {
        return Policy::template GetSmemSize_A<Problem>();
    }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
                                   const GWindow& g_window_,
                                   const DWindow& d_window_,
                                   OWindow& o_window_,
                                   TopkWeightDataType /*topk_weight*/,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t hidden_size,
                                   index_t intermediate_size)
    {
        _Pragma("clang diagnostic push")
            _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
        constexpr auto NEG1  = number<-1>{};
        constexpr auto I0    = number<0>{};
        constexpr auto I1    = number<1>{};
        constexpr auto TRUE  = bool_constant<true>{};
        constexpr auto FALSE = bool_constant<false>{};

        CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
        CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
            reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
            Policy::template GetSmemSize_A<Problem>());

        auto g_view = g_window_.get_bottom_tensor_view();
        auto u_view = [&]() {
            if constexpr(IsGateOnly)
            {
                return g_view;
            }
            else
            {
                index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
                index_t kr_0 = hidden_size / BlockShape::Block_Kr0;

                const GDataType* g_ptr =
                    g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
                const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
                const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
                    u_ptr,
                    make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                    number<kAlignmentG>{},
                    number<1>{});

                const auto u_view_1_ = pad_tensor_view(
                    u_view_,
                    make_tuple(number<BlockShape::Block_Nr0>{},
                               number<BlockShape::Block_Kr0>{},
                               number<BlockShape::Block_W0>{}),
                    sequence<PadIntermediateSize, PadHiddenSize, 0>{});
                return u_view_1_;
            }
        }();

        auto a_win = make_tile_window_linear(
            a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
        auto g_win =
            make_tile_window_linear(g_window_,
                                    Policy::template MakeGlobalTileDistribution_G<Problem>(),
                                    sequence<0, 1, 1>{});
        auto d_win =
            make_tile_window_linear(d_window_,
                                    Policy::template MakeGlobalTileDistribution_D<Problem>(),
                                    sequence<0, 1, 1>{});
        auto o_win = make_tile_window_linear(
            o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());

        using g_thread_type = decltype(load_tile(g_win));
        using d_thread_type = decltype(load_tile(d_win));

        using WarpGemm0  = decltype(Policy::template GetWarpGemm0<Problem>());
        using WarpGemm1  = decltype(Policy::template GetWarpGemm1<Problem>());
        auto warp_gemm_0 = WarpGemm0{};
        auto warp_gemm_1 = WarpGemm1{};

        // issues_warps_lanes
        auto a_sst_win0 =
            make_tile_window(make_tensor_view<address_space_enum::lds>(
                                 smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
                             Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
                             {0, 0, 0});

        auto a_sst_win1 =
            make_tile_window(make_tensor_view<address_space_enum::lds>(
                                 smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
                             Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
                             {0, 0, 0});

        // m*k
        auto a_sld_win0 = [&]() {
            using WG = WarpGemm0;
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
                      sequence<BlockShape::Repeat_K0>>,
                tuple<sequence<1>>,
                tuple<sequence<1>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
                a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
                Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        // m*k
        auto a_sld_win1 = [&]() {
            using WG = WarpGemm0;
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<>,
                tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
                      sequence<BlockShape::Repeat_K0>>,
                tuple<sequence<1>>,
                tuple<sequence<1>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
                a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
                Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        auto bridge_sst_win = [&]() {
            return make_tile_window(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<YDataType*>(smem),
                    Policy::template MakeBridgeLdsStoreDesc<Problem>()),
                Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
                {0, 0});
        }();

        auto bridge_sld_win = [&]() {
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<YDataType*>(smem),
                    Policy::template MakeBridgeLdsLoadDesc<Problem>()),
                Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
                {0, 0},
                Policy::template MakeYTileDistribution<Problem>());
        }();

        // also OK with C array, 2 register buffer
        statically_indexed_array<g_thread_type, 2> gs;

        constexpr auto issues_a = number<a_win.get_num_of_access()>{};
        constexpr auto issues_g = number<g_win.get_num_of_access()>{};
        // constexpr auto issues_d = number<d_win.get_num_of_access()>{};
        // constexpr auto issues_o = number<o_win.get_num_of_access()>{};
        constexpr auto issues_gemm0 =
            number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
                   warp_gemm_0.get_num_of_access()>{};
        constexpr auto issues_gemm1 =
            number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
                   warp_gemm_1.get_num_of_access()>{};
        // constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};

        const index_t num_blocks_k0 =
            (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
        const index_t num_blocks_n1 =
            (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;

        using a_thread_type = decltype(load_tile(a_sld_win0));
        statically_indexed_array<a_thread_type, 2> as;

        auto gld_a = [&]<typename PreNop = bool_constant<false>>(
                         auto& a_store_, auto i_access, PreNop = {}) {
            async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
        };
        auto move_a = [&]() {
            move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
        };
        auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
            load_tile_raw(a_, win_, i_access);
        };

        auto gld_g = [&]<typename PreNop = bool_constant<false>>(
                         auto& g_, auto i_access, PreNop = {}) {
            if constexpr(IsGateOnly)
            {
                // TODO: hack!
                if constexpr(i_access.value == 0)
                {
                    g_win.bottom_tensor_view_ = g_view;
                }
                else if constexpr(i_access.value == issues_g / 2)
                {
                    g_win.bottom_tensor_view_ = u_view;
                }
            }
            load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
        };
        auto move_g = [&]() {
            move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
        };

        statically_indexed_array<d_thread_type, 2> ds;

        auto gld_d = [&]<typename PreNop = bool_constant<false>>(
                         auto& d_, auto i_access, PreNop = {}) {
            load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
        };
        auto move_d = [&]() {
            // d move along gemm-n
            move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
        };

        auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
                                auto& o_, auto i_access, PreNop = {}) {
            update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
        };

        auto acc_0  = Policy::template MakeCBlockTile_Gemm0<Problem>();
        auto acc_1s = generate_tuple(
            [&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});

        // clang-format off
        auto gemm_0 = [&]<typename PostNop = bool_constant<false>>(
                          auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
            using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
            constexpr auto repeat_sub = WarpGemm::get_num_of_access();
            constexpr auto repeat_m   = BlockShape::Repeat_M0;
            // constexpr auto repeat_n = BlockShape::Repeat_N0;
            constexpr auto repeat_k   = BlockShape::Repeat_K0;
            // loop order n->m->k
            constexpr auto i_sub = i_access % repeat_sub;
            constexpr auto i_k   = (i_access / repeat_sub) % repeat_k;
            constexpr auto i_m   = (i_access / (repeat_sub * repeat_k)) % repeat_m;
            constexpr auto i_n   = (i_access / (repeat_sub * repeat_k)) / repeat_m;

            using AWarpTensor = typename WarpGemm::AWarpTensor;
            using BWarpTensor = typename WarpGemm::BWarpTensor;
            using CWarpTensor = typename WarpGemm::CWarpTensor;
            using AWarpDstr   = typename WarpGemm::AWarpDstr;
            using BWarpDstr   = typename WarpGemm::BWarpDstr;
            using CWarpDstr   = typename WarpGemm::CWarpDstr;

            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
            constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

            AWarpTensor w_a;
            w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

            BWarpTensor w_b;
            w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
                merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            CWarpTensor w_c;
            w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

            warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});

            t_c.set_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                w_c.get_thread_buffer());
        };
        // clang-format on

        // clang-format off
        auto gemm_1 = [&]<typename PostNop = bool_constant<false>>(
                          auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
            using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
            constexpr auto repeat_sub = WarpGemm::get_num_of_access();
            constexpr auto repeat_m   = BlockShape::Repeat_M0;
            // constexpr auto repeat_n = BlockShape::Repeat_N0;
            constexpr auto repeat_k   = BlockShape::Repeat_K0;
            // loop order n->m->k
            constexpr auto i_sub = i_access % repeat_sub;
            constexpr auto i_k   = (i_access / repeat_sub) % repeat_k;
            constexpr auto i_m   = (i_access / (repeat_sub * repeat_k)) % repeat_m;
            constexpr auto i_n   = (i_access / (repeat_sub * repeat_k)) / repeat_m;

            using AWarpTensor = typename WarpGemm::AWarpTensor;
            using BWarpTensor = typename WarpGemm::BWarpTensor;
            using CWarpTensor = typename WarpGemm::CWarpTensor;
            using AWarpDstr   = typename WarpGemm::AWarpDstr;
            using BWarpDstr   = typename WarpGemm::BWarpDstr;
            using CWarpDstr   = typename WarpGemm::CWarpDstr;

            constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
            constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
            constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

            AWarpTensor w_a;
            w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

            BWarpTensor w_b;
            w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
                merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));

            CWarpTensor w_c;
            w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
                merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
                merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

            warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});

            t_c.
set_y_sliced_thread_data
(
merge_sequences
(
sequence
<
i_m
,
i_n
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
),
w_c
.
get_thread_buffer
());
};
// clang-format on
_Pragma
(
"clang diagnostic pop"
);
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
// mfma(that can reuse the B matrix) only affected by M repeat.
auto
pipeline_gemm0
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_sld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_a_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I1
],
a_sld_win1
,
number
<
NEXT_SCI
(
c_sld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win0
,
number
<
NEXT_SCI
(
c_gld_a_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
constexpr
auto
c_sld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_a_1
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
SLD_A
)
sld_a
(
as
[
I0
],
a_sld_win0
,
number
<
NEXT_SCI
(
c_sld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_A
)
gld_a
(
a_sst_win1
,
number
<
NEXT_SCI
(
c_gld_a_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
});
move_g
();
move_a
();
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
};
auto
pipeline_gemm0_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm0
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_0
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_0
(
acc_0
,
as
[
I0
],
gs
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_g
(
gs
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
block_sync_load_raw
(
issues_g
);
sld_a
(
as
[
I1
],
a_sld_win1
,
NEG1
);
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
constexpr
auto
last_nop
=
[
&
]()
{
if
constexpr
(
i_issue
==
(
total_loops
-
1
))
return
TRUE
;
else
return
FALSE
;
}();
gemm_0
(
acc_0
,
as
[
I1
],
gs
[
I1
],
i_issue
,
last_nop
);
// last gemm has nop
});
};
auto
y
=
Policy
::
template
MakeYBlockTile
<
Problem
>();
auto
pipeline_bridge
=
[
&
]()
{
// cast to Y data
auto
y_pre
=
cast_tile
<
YDataType
>
(
acc_0
);
store_tile
(
bridge_sst_win
,
y_pre
);
clear_tile
(
acc_1s
(
I0
));
// wave_barrier();
load_tile
(
y
,
bridge_sld_win
);
clear_tile
(
acc_1s
(
I1
));
};
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto
pipeline_gemm1
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
constexpr
auto
c_gld_b_1
=
MAKE_SC
();
constexpr
auto
c_gst_o_1
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I0
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
});
move_d
();
// move_o();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_1
,
i_issue
)
>
{});
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_1
,
i_issue
)
>
{});
}
});
move_d
();
};
auto
pipeline_gemm1_head
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gld_b_0
=
MAKE_SC
();
// compute buffer 0
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I0
],
y
,
ds
[
I0
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GLD_B
)
gld_d
(
ds
[
I1
],
number
<
NEXT_SCI
(
c_gld_b_0
,
i_issue
)
>
{});
});
move_d
();
};
auto
pipeline_gemm1_tail
=
[
&
]()
{
constexpr
index_t
total_loops
=
issues_gemm1
;
constexpr
auto
sr
=
Policy
::
template
GetSequencer_1
<
Problem
>();
static_assert
(
sr
.
size
()
==
total_loops
);
constexpr
auto
c_gst_o_0
=
MAKE_SC
();
// compute buffer 1
static_for
<
0
,
total_loops
,
1
>
{}([
&
](
auto
i_issue
)
{
gemm_1
(
acc_1s
[
I1
],
y
,
ds
[
I1
],
i_issue
);
constexpr
index_t
slot
=
sr
.
at
(
i_issue
);
if
constexpr
(
slot
&
GST_O
)
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I0
]);
atomic_add_o
(
out
,
number
<
NEXT_SCI
(
c_gst_o_0
,
i_issue
)
>
{});
}
});
{
auto
out
=
cast_tile
<
ODataType
>
(
acc_1s
[
I1
]);
atomic_add_o
(
out
,
NEG1
);
}
};
// start of pipeline
// clang-format off
gld_a
(
a_sst_win0
,
NEG1
,
TRUE
);
gld_g
(
gs
[
I0
],
NEG1
,
TRUE
);
move_a
();
move_g
();
clear_tile
(
acc_0
);
// preload for next round
gld_a
(
a_sst_win1
,
NEG1
);
gld_g
(
gs
[
I1
],
NEG1
);
// make sure a,g loaded
block_sync_load_raw
(
issues_a
+
issues_g
);
lds_load_fence
();
// we manually unroll double buffer inside hot loop
const
index_t
iters_0
=
(
num_blocks_k0
-
2
)
/
2
;
index_t
i_0
=
0
;
// (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while
(
i_0
++
<
iters_0
)
{
pipeline_gemm0
();
}
pipeline_gemm0_tail
();
pipeline_bridge
();
const
index_t
iters_1
=
(
num_blocks_n1
-
2
)
/
2
;
index_t
i_1
=
0
;
// (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head
();
while
(
i_1
++
<
iters_1
)
{
pipeline_gemm1
();
}
pipeline_gemm1_tail
();
// clang-format on
}
};
}
// namespace ck_tile
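For a concrete feel of the trip counts in the hot loop above: two k-blocks are preloaded before the loop, each pipeline_gemm0() call computes two k-blocks (one per LDS buffer) while prefetching the next two, and pipeline_gemm0_tail() computes the final two, so 2 * iters_0 + 2 k-blocks are covered. A minimal standalone sketch of that arithmetic, using illustrative sizes (hidden_size = 4096 and Block_K0 = 128 are assumptions, not values from this commit):

// Standalone sketch of the double-buffer trip-count arithmetic; sizes are assumptions.
#include <cstdio>

int main()
{
    constexpr int hidden_size   = 4096; // illustrative
    constexpr int Block_K0      = 128;  // illustrative
    constexpr int num_blocks_k0 = (hidden_size + Block_K0 - 1) / Block_K0; // 32 (ceil-div, as above)
    constexpr int iters_0       = (num_blocks_k0 - 2) / 2;                 // 15 hot-loop iterations
    // 2 k-blocks per hot-loop iteration plus 2 computed by the tail
    static_assert(2 * iters_0 + 2 == num_blocks_k0, "holds when num_blocks_k0 is even");
    std::printf("num_blocks_k0=%d iters_0=%d\n", num_blocks_k0, iters_0);
}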
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"

namespace ck_tile {

struct FusedMoeGemmPipelineFlatmmPolicy
{
    CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
    {
        // TODO: always 1 dword
        return 1;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
    {
        // using async
        constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
        constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
    {
        constexpr index_t copy_bytes = [&]() { return 16; }();
        constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
    {
        constexpr index_t copy_bytes = [&]() { return 16; }();
        constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
        static_assert(copy_bytes % data_bytes == 0);
        return copy_bytes / data_bytes;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
    {
        if constexpr(Problem::Traits::OAtomic == 1)
        {
            // packed fp16/bf16 atomic
            static_assert(sizeof(typename Problem::ODataType) == 2);
            return 2;
        }
        else if constexpr(Problem::Traits::OAtomic == 2)
        {
            // fp32 atomic
            return 1;
        }
        else
        {
            return 16 / sizeof(typename Problem::ODataType);
        }
    }
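A compact way to read the branches above: the output alignment is the widest element group a single store can cover under each atomic mode. A standalone restatement of the same arithmetic, with plain ints and no library types (the 2-byte element size is the usual one for bf16/fp16):

// Standalone restatement of the GetAlignment_O logic; o_data_bytes = 2 assumes bf16/fp16.
constexpr int alignment_o(int o_data_bytes, int o_atomic)
{
    if(o_atomic == 1)
        return 2; // packed fp16/bf16 atomic: 2 elements per op
    else if(o_atomic == 2)
        return 1; // fp32 atomic: 1 element per op
    else
        return 16 / o_data_bytes; // plain vector store: pack to 16 bytes (dwordx4)
}
static_assert(alignment_o(2, 1) == 2, "packed atomic");
static_assert(alignment_o(2, 2) == 1, "fp32 atomic");
static_assert(alignment_o(2, 0) == 8, "bf16 non-atomic: 8 elements = 16 bytes");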
    template <typename DataType_>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
    {
        // TODO: this is for 3d layout
        return 16 / sizeof(remove_cvref_t<DataType_>);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
    {
        return GetSmemKPack<typename Problem::ADataType>();
    }

    // used for bridge LDS shuffle
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
    {
        // TODO: this should match the mfma layout
        return 16 / sizeof(typename Problem::YDataType);
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
    {
        constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
        constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
        static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
        return a_sld_desc.get_element_space_size();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
    {
        constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
        constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
        static_assert(bridge_sld_desc.get_element_space_size() ==
                      bridge_sst_desc.get_element_space_size());
        return bridge_sld_desc.get_element_space_size();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t a_lds      = GetSmemSize_A<Problem>();
        constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
        return max(a_lds, bridge_lds);
    }

    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
    {
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;
        if constexpr(get_warp_size() < K_rem)
        {
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lanes within the same wave are along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav>,
                                                 sequence<K_wav, K_lan, K_vec>>,
                                           tuple<sequence<1, 2>, sequence<2>>,
                                           tuple<sequence<1, 0>, sequence<1>>,
                                           sequence<1, 2>,
                                           sequence<0, 2>>{});
        }
        else
        {
            constexpr index_t K_lan = K_rem;
            constexpr index_t M_lan = get_warp_size() / K_lan;
            constexpr index_t M_wav = NumWarps;
            static_assert(MPerBlock % (M_lan * M_wav) == 0,
                          "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav, M_lan>,
                                                 sequence<K_lan, K_vec>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<1>, sequence<2, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
    }
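To make the branch selection concrete, here is the arithmetic for one plausible configuration (all numbers are illustrative assumptions, not taken from this commit): with KPerBlock = 128 and Alignment = 8 on a 64-lane wave, only 16 lanes are needed along K, so the second branch splits each wave as 4 lanes along M times 16 lanes along K.

// Worked example of the SimpleMxK lane split; every size here is an assumption.
constexpr int warp_size = 64;
constexpr int K_vec     = 8;           // Alignment
constexpr int K_rem     = 128 / K_vec; // 16 lanes needed along gemm-k
static_assert(warp_size >= K_rem);     // -> second branch above
constexpr int K_lan = K_rem;              // 16 lanes along gemm-k
constexpr int M_lan = warp_size / K_lan;  // 4 lanes along gemm-m inside one wave
constexpr int M_wav = 4;                  // NumWarps (assumed)
constexpr int M_rep = 32 / (M_lan * M_wav); // MPerBlock = 32 -> 2 repeats
static_assert(M_rep * M_lan * M_wav == 32);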
    // optimized version for async; note this is NOT the same as the simple MxK dist (pay attention!!)
    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
    {
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec;
        if constexpr(get_warp_size() <= K_rem)
        {
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lanes within the same wave are along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
            // NOTE: no swap, but hard to avoid LDS bank conflicts
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M_rep, M_wav>,
                                                 sequence<K_wav, K_lan, K_vec>>,
                                           tuple<sequence<1, 2>, sequence<2>>,
                                           tuple<sequence<1, 0>, sequence<1>>,
                                           sequence<1, 2>,
                                           sequence<0, 2>>{});
        }
        else
        {
            constexpr index_t K_lan = K_rem;
            constexpr index_t M_lan = get_warp_size() / K_lan;
            constexpr index_t M_wav = NumWarps;
            static_assert(MPerBlock % (M_lan * M_wav) == 0,
                          "this tile size is too small, please check");
            constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
            // NOTE: swapped for bank-conflict-free LDS loads
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           // Note M_wav (num waves) is the fastest dim, different
                                           // from the simple 2d distribution
                                           tuple<sequence<M_rep, M_lan, M_wav>,
                                                 sequence<K_lan, K_vec>>,
                                           tuple<sequence<1>, sequence<1, 2>>,
                                           tuple<sequence<2>, sequence<1, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 1>>{});
        }
    }

    template <index_t WarpPerBlock_N_,
              index_t WarpPerBlock_K_,
              index_t Repeat_N_,
              index_t Repeat_K_,
              index_t WarpSize_,
              index_t Alignment_>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
                                             sequence<Repeat_K_, WarpPerBlock_K_>,
                                             sequence<WarpSize_, Alignment_>>,
                                       tuple<sequence<1, 2>, sequence<3>>,
                                       tuple<sequence<1, 1>, sequence<0>>,
                                       sequence<1, 2, 3>,
                                       sequence<0, 0, 1>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
    {
        constexpr index_t Block_M_   = Problem::BlockShape::Block_M0;
        constexpr index_t Block_K_   = Problem::BlockShape::Block_K0;
        constexpr index_t NumWarps_  = Problem::BlockShape::NumWarps;
        constexpr index_t Alignment_ = GetAlignment_A<Problem>();
        return MakeGlobalTileDistribution_SimpleMxK_Async<Block_M_,
                                                          Block_K_,
                                                          NumWarps_,
                                                          Alignment_>();
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
    {
        constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
        // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
        using S_ = typename Problem::BlockShape;
        if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            // number<S_::WarpPerBlock_N0>{}.rrr();
            // number<S_::Repeat_N0>{}.eee();
            return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
                                                      S_::WarpPerBlock_K0,
                                                      S_::Repeat_N0, /// hidden_radio_0,
                                                      S_::Repeat_K0,
                                                      get_warp_size(),
                                                      GetAlignment_G<Problem>()>();
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
    {
        constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
        using S_ = typename Problem::BlockShape;
        if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
                                                      S_::WarpPerBlock_K1,
                                                      S_::Repeat_N1,
                                                      S_::Repeat_K1,
                                                      get_warp_size(),
                                                      GetAlignment_D<Problem>()>();
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
    {
        using S_       = remove_cvref_t<typename Problem::BlockShape>;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        // using CDataType = typename WarpGemm::CDataType;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
                                             sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        return c_block_dstr;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
    {
        // A async->LDS
        constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        constexpr index_t Block_K = Problem::BlockShape::Block_K0;
        // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
        constexpr index_t warpSize = ck_tile::get_warp_size();
        constexpr index_t NumWarps = Problem::BlockShape::NumWarps;

        constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = KPack;                     // pad between warps

        static_assert(Block_K % KVector == 0);
        constexpr index_t LanesPerK = Block_K / KVector; // how many threads are loading K
        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK > NumWarps)
            {
                // TODO: need multiple issues along K to load all data
            }
            else
            {
                constexpr index_t wavesPerM = NumWarps / wavesPerK;
                constexpr index_t NumIssues = Block_M / wavesPerM;
                constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                    make_tuple(number<NumIssues>{},  // m0
                               number<wavesPerM>{}, // m1
                               number<wavesPerK>{}, // k0
                               number<warpSize>{},  // k1
                               number<KVector>{}),  // k2
                    make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{},  // m0
                               number<wavesPerK * (warpSize * KVector + KPad)>{}, // m1
                               number<warpSize * KVector + KPad>{},               // k0
                               number<KVector>{},                                 // k1
                               number<1>{}),                                      // k2
                    number<KVector>{}, // lds store vector (actually no explicit store)
                    number<1>{});

                constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                    lds_block_desc_0,
                    make_tuple(
                        make_pass_through_transform(number<NumIssues>{}),
                        make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                        make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
                    make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
                return lds_block_desc_issues_warps_lanes;
            }
        }
        else
        {
            // lanes within a wave load different M but the same K
            static_assert(warpSize % LanesPerK == 0);
            constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
            constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},   // m0
                           number<LaneGroups>{}, // m1
                           number<NumWarps>{},   // m2
                           number<LanesPerK>{},  // k0
                           number<KVector>{}),   // k1
                make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{}, // m0
                           number<Block_K>{},                                // m1
                           number<warpSize * KVector + KPad>{},              // m2
                           number<KVector>{},                                // k0
                           number<1>{}),                                     // k1
                number<KVector>{}, // lds store vector (actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<NumIssues>{}),
                           make_pass_through_transform(number<NumWarps>{}),
                           make_merge_transform(make_tuple(
                               number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
                make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
            return lds_block_desc_issues_warps_lanes;
        }
    }
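The KPad term above gives each (wave, k) slice a pitch of warpSize * KVector + KPad elements rather than a power of two, which staggers consecutive m-rows across LDS banks. A standalone restatement with plausible bf16 numbers (the element and copy sizes are assumptions for illustration):

// Pitch arithmetic for the padded LDS layout; bf16 = 2 bytes is assumed.
constexpr int warpSize = 64;
constexpr int KVector  = 4 / 2;  // 4-byte async copy / sizeof(bf16) = 2 elements
constexpr int KPack    = 16 / 2; // 16-byte LDS read  / sizeof(bf16) = 8 elements
constexpr int KPad     = KPack;
constexpr int pitch    = warpSize * KVector + KPad; // 136 elements, not a power of two
static_assert(pitch == 136);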
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
    {
        // A async->LDS
        // Note that this descriptor only constructs the layout inside LDS;
        // in the real gemm pipeline, ds_read may not follow this pattern
        // (it may follow the one in tile_distribution).
        // The code below is almost the same as the SmemStore dist, with two differences:
        // 1) the GuaranteedLastDimensionVectorLength of the naive tensor desc is modified
        // 2) the returned descriptor is in an NxK 2d layout
        constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        constexpr index_t Block_K = Problem::BlockShape::Block_K0;
        // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
        constexpr index_t warpSize = ck_tile::get_warp_size();
        constexpr index_t NumWarps = Problem::BlockShape::NumWarps;

        constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = KPack;                     // pad between warps

        static_assert(Block_K % KVector == 0);
        constexpr index_t LanesPerK = Block_K / KVector; // how many threads are loading K
        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK >= NumWarps)
            {
                // TODO: need multiple issues along K to load all data
            }
            else
            {
                constexpr index_t wavesPerM = NumWarps / wavesPerK;
                constexpr index_t NumIssues = Block_M / wavesPerM;
                constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                    make_tuple(number<NumIssues>{},  // m0
                               number<wavesPerM>{}, // m1
                               number<wavesPerK>{}, // k0
                               number<warpSize>{},  // k1
                               number<KVector>{}),  // k2
                    make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{},  // m0
                               number<wavesPerK * (warpSize * KVector + KPad)>{}, // m1
                               number<warpSize * KVector + KPad>{},               // k0
                               number<KVector>{},                                 // k1
                               number<1>{}),                                      // k2
                    number<KPack>{}, // lds load vector
                    number<1>{});

                constexpr auto lds_desc_m_k = transform_tensor_descriptor(
                    lds_block_desc_0,
                    make_tuple(
                        make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
                        make_merge_transform(make_tuple(
                            number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
                    make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));
                return lds_desc_m_k;
            }
        }
        else
        {
            // lanes within a wave load different M but the same K
            static_assert(warpSize % LanesPerK == 0);
            constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
            constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},   // m0
                           number<LaneGroups>{}, // m1
                           number<NumWarps>{},   // m2
                           number<LanesPerK>{},  // k0
                           number<KVector>{}),   // k1
                make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{}, // m0
                           number<Block_K>{},                                // m1
                           number<warpSize * KVector + KPad>{},              // m2
                           number<KVector>{},                                // k0
                           number<1>{}),                                     // k1
                number<KPack>{}, // lds load vector
                number<1>{});

            constexpr auto lds_desc_m_k = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_merge_transform(make_tuple(
                               number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
                           make_merge_transform(
                               make_tuple(number<LanesPerK>{}, number<KVector>{}))),
                make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));
            return lds_desc_m_k;
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
    {
        constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        constexpr index_t Block_N = Problem::BlockShape::Block_N0;

        constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = 0;                         // pad between warps

        constexpr auto desc = make_naive_tensor_descriptor(
            make_tuple(number<Block_M>{}, number<Block_N>{}),
            make_tuple(number<Block_N + KPad>{}, number<1>{}),
            number<KVector>{},
            number<1>{});
        return desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
    {
        constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        constexpr index_t Block_N = Problem::BlockShape::Block_N0;

        constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = 0;                         // KVector; // pad between warps

        constexpr auto desc = make_naive_tensor_descriptor(
            make_tuple(number<Block_M>{}, number<Block_N>{}),
            make_tuple(number<Block_N + KPad>{}, number<1>{}),
            number<KVector>{},
            number<1>{});
        return desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
    {
        constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
        constexpr index_t Repeat_N       = Problem::BlockShape::Repeat_N0;
        constexpr index_t Repeat_M       = Problem::BlockShape::Repeat_M0;

        constexpr index_t kAMLane     = 16;
        constexpr index_t kABKLane    = 4;
        constexpr index_t kABKPerLane = 4;
        constexpr index_t KPack       = kABKPerLane;

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<Repeat_M>{},       // m
                       number<Repeat_N>{},       // n
                       number<WarpPerBlock_N>{}, // n
                       number<kABKLane>{},       // n
                       number<kAMLane>{},        // m
                       number<KPack>{}),         // n
            make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
                       number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{},            // n
                       number<kABKLane * kAMLane * KPack>{},                             // n
                       number<kAMLane * KPack>{},                                        // n
                       number<KPack>{},                                                  // m
                       number<1>{}),                                                     // n
            number<KPack>{}, // lds store vector (actually no explicit store)
            number<1>{});

        constexpr auto desc = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
                       make_merge_transform(make_tuple(number<Repeat_N>{},
                                                       number<WarpPerBlock_N>{},
                                                       number<kABKLane>{},
                                                       number<KPack>{}))),
            make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
        return desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
    {
        using S_ = typename Problem::BlockShape;
        // A is vgpr, B is agpr; but since we transposed, these also need to be swapped.
        // TODO: this is ugly
        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
        // TODO: ugly
        if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
                     S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
        {
            return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
                WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>, 2>>{};
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
                          std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
                          S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
        {
            return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
                WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>, 2>>{};
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
    {
        // This function returns a seq<...> used to tag gld/sld/valu... slots inside the mfma
        // sequence; the purpose is to hide those instructions under mfma.
        // Every value inside seq<...> is a bit mask indicating a specific operation.
        using S_ = typename Problem::BlockShape;
        constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
        constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
        constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
        if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                     S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
                     S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                     S_::Block_N1 == 128)
        {
            // Total 64 instructions: 32x buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
            // gld_a, 8x ds_read_b128 sld_a -- 64 slots in total :)
            // clang-format off
            constexpr auto seq_all =
                //       0      1      2      3      4      5      6      7
                sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,   // 0
                         GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,   // 1
                         GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,   // 2
                         GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,   // 3
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,       // 4
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,       // 5
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,       // 6
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};    // 7
            return seq_all;
            // clang-format on
        }
        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                          std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                          S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
                          S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
                          S_::Block_N1 == 128)
        {
            // Total 32 instructions: 16x buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
            // gld_a, 8x ds_read_b128 sld_a -- 32 slots in total :)
            // clang-format off
            constexpr auto seq_all =
                //       0      1      2      3      4      5      6      7
                sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,   // 0
                         GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,   // 1
                         GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,   // 2
                         GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
            return seq_all;
            // clang-format on
        }
    }
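Outside the library, the contract of these sequences is easy to state: one element per mfma issue, each element a bit mask naming the memory op (if any) to slip in behind that mfma. A minimal standalone sketch of reading such a sequence (the mask values mirror FusedMoeGemmPipelineSequencerEnum; the 8-slot sequence itself is made up for illustration):

// Standalone sketch: counting which ops an (invented) 8-slot sequencer hides under mfma.
#include <array>

enum : int { SLD_A_ = 1 << 0, GLD_A_ = 1 << 2, GLD_B_ = 1 << 3 }; // mirrors the enum's values

constexpr std::array<int, 8> seq{GLD_B_, GLD_A_, GLD_B_, GLD_A_, GLD_B_, SLD_A_, GLD_B_, 0};

constexpr int count(int mask)
{
    int n = 0;
    for(int s : seq)
        if(s & mask)
            ++n;
    return n;
}
static_assert(count(GLD_B_) == 4); // 4 B-loads interleaved across 8 mfma slots
static_assert(count(GLD_A_) == 2);
static_assert(count(SLD_A_) == 1);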
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
    {
        // This function returns a seq<...> used to tag gld/sld/valu... slots inside the mfma
        // sequence; the purpose is to hide those instructions under mfma.
        // Every value inside seq<...> is a bit mask indicating a specific operation.
        using S_ = typename Problem::BlockShape;
        constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
        constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
        if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                     S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
                     S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                     S_::Block_N1 == 128)
        {
            // 32x buffer-load-dwordx4 gld_b paired with 8x gst_o -- 64 slots in total :)
            // clang-format off
            constexpr auto seq_all =
                //       0      1      2      3      4      5      6      7
                sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 0
                         GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 1
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 2
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 3
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 4
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 5
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 6
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};   // 7
            return seq_all;
            // clang-format on
        }
        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                          std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                          S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
                          S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
                          S_::Block_N1 == 128)
        {
            // 16x buffer-load-dwordx4 gld_b paired with 8x gst_o -- 32 slots in total :)
            // clang-format off
            constexpr auto seq_all =
                //       0      1      2      3      4      5      6      7
                sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 0
                         GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 1
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 2
                         GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};   // 3
            return seq_all;
            // clang-format on
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
    {
        using S_ = typename Problem::BlockShape;
        constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
        // TODO: ugly
        if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                     S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
        {
            return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
                WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>, 2>>{};
        }
        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
                          std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
                          S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
        {
            return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
                WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>, 2>>{};
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
    {
        using S_        = remove_cvref_t<typename Problem::BlockShape>;
        using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
        using CDataType = typename WarpGemm::CDataType;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
                                             sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
    {
        using S_        = remove_cvref_t<typename Problem::BlockShape>;
        using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        using CDataType = typename WarpGemm::CDataType;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
                                             sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // this is used as the A matrix for the 2nd gemm
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
    {
        using S_       = remove_cvref_t<typename Problem::BlockShape>;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

        // TODO: all waves are along different N, but the same M
        constexpr auto y_outer_dstr_enc =
            tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
                                       tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
                                       tuple<sequence<0>>,
                                       tuple<sequence<0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
        constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
        constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
        return y_block_dstr;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
    {
        constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
        auto y_block_tensor =
            make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
        return y_block_tensor;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
    {
        using S_ = typename Problem::BlockShape;
        if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
                     S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                     S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
        {
            return Flatmm_32x512x128_1x4x1_16x16x32_BF16{};
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
                          std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
                          S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
        {
            return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
    {
        using S_ = typename Problem::BlockShape;
        if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                     std::is_same_v<typename Problem::TopkWeightDataType, float> &&
                     S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
                     S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
        {
            return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
        }
        else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
                          std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
                          std::is_same_v<typename Problem::TopkWeightDataType, float> &&
                          S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
                          S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
        {
            return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
        }
    }
};

} // namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"

namespace ck_tile {

/*
This pipeline deals with a gemm (actually 2 gemms) where one input is very small (token) and
one is very big (weight). We design the pipeline such that all waves lie along the gemm-N dim
(gemm-M has only 1 wave):

    <----- gemm-N ------>
    +----+----+----+----+
    | w0 | w1 | w2 | w3 |  gemm-M
    +----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
    using Problem    = remove_cvref_t<Problem_>;
    using Policy     = remove_cvref_t<Policy_>;
    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType   = typename Problem::ADataType;
    using GDataType   = typename Problem::GDataType;
    using DDataType   = typename Problem::DDataType;
    using AccDataType = typename Problem::AccDataType;
    using ODataType   = typename Problem::ODataType;

    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "flatmm_uk";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0, max(smem_1, smem_bridge));
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
        return MRepeat;
    }

    // TODO: properly support scatter/gather
    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / KLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
        return coords;
    }

    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}([&](auto i) {
            row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
        });
        return row_ids;
    }

    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                       const TopkWeightDataType* sorted_weight_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<TopkWeightDataType, n_size> w;
        static_for<0, n_size, 1>{}([&](auto i) {
            w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
        });
        return w;
    }

    // TODO: this row id is before the shuffle atomic; need to use the acc distribution
    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
    {
        constexpr index_t MLanes   = BlockShape::Warp_M1;
        constexpr index_t Repeat_M = BlockShape::Repeat_M1;

        auto base_coord = threadIdx.x % MLanes + base_offset;

        array<index_t, Repeat_M> coords;
        static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
        return coords;
    }

    template <typename Karg>
    CK_TILE_DEVICE auto operator()(const Karg& kargs,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t sorted_tile_id,
                                   index_t intermediate_tile_id)
    {
        constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;

        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
        index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
        index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;

        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
        index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
        index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;

        // nr*kr*w
        index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id * BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
        index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id * BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)

        auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
        auto row_ids_a    = GetRowID(
            row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
        auto a_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
            },
            number<row_ids_a.size()>{});
        auto a_res =
            make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ADataType));

        auto g_win = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});
            auto g_window_ = make_tile_window_linear_raw(
                g_view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>(),
                sequence<0, 1, 1>{});
            return g_window_;
        }();
        auto g_res    = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
                                       number<decltype(g_win)::NumAccess_NonLinear>{});

        const auto d_win = [&]() {
            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                     interm_idx_kr1 * BlockShape::Block_W1;
            // note interm_idx_nr0 is along the gemm-k dim of the 2nd gemm
            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                number<kAlignmentD>{},
                number<1>{});
            const auto d_window_ = make_tile_window_linear_raw(
                d_view_,
                make_tuple(number<BlockShape::Block_Nr1>{},
                           number<BlockShape::Block_Kr1>{},
                           number<BlockShape::Block_W1>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_D<Problem>(),
                sequence<0, 1, 1>{});
            return d_window_;
        }();
        auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;

        // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
        //       block-k=512, block-n=128
        // wg                                  |<----- W_ ----->|
        // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
        //   y    p          y      y         p      p      y
        //   1    2                 0(imm)
        auto d_coords = [&]() {
            constexpr index_t Nr_  = 2;
            constexpr index_t Nw_  = 4;
            constexpr index_t Kr0_ = 4;
            constexpr index_t Kr1_ = 4;
            constexpr index_t Kl_  = 4;
            constexpr index_t Nl_  = 16;
            constexpr index_t Kv_  = 8;
            constexpr index_t W_   = Kl_ * Nl_ * Kv_;
            constexpr index_t num_offsets_ = Nr_ * Kr0_;
            index_t base_os_ = (threadIdx.x % 64) * Kv_ +
                               (threadIdx.x / 64) * shared_intermediate_size_1 * Nl_; // Kr0_ * Kr1_ * W_;
            return generate_tuple(
                [&](auto i) {
                    constexpr auto i_nr_  = number<i % Nr_>{};
                    constexpr auto i_kr0_ = number<i / Nr_>{};
                    return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
                           base_os_;
                },
                number<num_offsets_>{});
        }();

        auto o_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
            },
            number<row_ids_a.size()>{});
        auto o_flags =
            generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
                           number<row_ids_a.size()>{});

        auto bridge_sst_win = [&]() {
            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(reinterpret_cast<YDataType*>(smem), desc_),
                desc_.get_lengths(),
                {0, 0},
                dist_);
        }();

        auto o_res =
            make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ODataType));

        auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
        auto w_scale = GetWeightScale(
            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

        auto uk_0  = Policy::template GetUK_0<Problem>();
        auto acc_0 = uk_0(a_res,
                          a_coords,
                          g_res,
                          g_coords,
                          smem,
                          kargs.hidden_size,
                          BlockShape::Block_K0, // tile offset for B matrix each unroll
                          BlockShape::Block_Kr0 * BlockShape::Block_W0); // tile offset for B matrix each unroll

        sweep_tile(
            acc_0,
            [&](auto idx0, auto idx1) {
                fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
                typename Problem::GateActivation{}(v_, v_);
                acc_0(idx0) = v_.x;
                acc_0(idx1) = v_.y;
            },
            sequence<1, 2>{});

        auto y_pre = cast_tile<YDataType>(acc_0);
        block_sync_lds();
        store_tile(bridge_sst_win, y_pre);
        block_sync_lds();

        auto uk_1 = Policy::template GetUK_1<Problem>();
        uk_1(d_res,
             d_coords,
             o_res,
             o_coords,
             o_flags,
             smem,
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
             BlockShape::Block_N1);                               // along N
    }
};

} // namespace ck_tile
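For orientation, a rough sketch of how a kernel might drive this pipeline: one workgroup per (sorted token tile, intermediate tile) pair, with LDS sized by GetSmemSize(). The kernel name and grid mapping below are assumptions for illustration, not this commit's actual kernel (which lives in fused_moegemm_kernel.hpp):

// Hypothetical HIP-style driver sketch; only GetSmemSize() and operator() are taken
// from the pipeline above, everything else (name, grid mapping, Kargs) is assumed.
template <typename Pipeline, typename Kargs>
__global__ void fused_moe_pipeline_sketch(Kargs kargs)
{
    __shared__ char smem[Pipeline::GetSmemSize()];
    const ck_tile::index_t sorted_tile_id       = blockIdx.x; // tile over sorted tokens
    const ck_tile::index_t intermediate_tile_id = blockIdx.y; // tile over intermediate dim
    Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
}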
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// TODO: allow the 2 gemms to have different types
template <typename ADataType_,
          typename GDataType_,
          typename DDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename AScaleDataType_,
          typename GScaleDataType_,
          typename DScaleDataType_,
          typename YSmoothScaleDataType_,
          typename TopkWeightDataType_,
          typename IndexDataType_,  // data type for all indexing
          typename GateActivation_, // = ck_tile::element_wise::Silu,
          typename BlockShape_,     // should be FusedMoeGemmShape
          typename Traits_>
struct FusedMoeGemmPipelineProblem
{
    using ADataType   = remove_cvref_t<ADataType_>;
    using GDataType   = remove_cvref_t<GDataType_>;
    using DDataType   = remove_cvref_t<DDataType_>;
    using AccDataType = remove_cvref_t<AccDataType_>;
    using ODataType   = remove_cvref_t<ODataType_>;

    using AScaleDataType       = remove_cvref_t<AScaleDataType_>;
    using GScaleDataType       = remove_cvref_t<GScaleDataType_>;
    using DScaleDataType       = remove_cvref_t<DScaleDataType_>;
    using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
    using TopkWeightDataType   = remove_cvref_t<TopkWeightDataType_>;
    using IndexDataType        = remove_cvref_t<IndexDataType_>;

    // the input of the next gemm should have the same type as A
    using YDataType = ADataType;

    using GateActivation = remove_cvref_t<GateActivation_>;
    using BlockShape     = remove_cvref_t<BlockShape_>;
    using Traits         = remove_cvref_t<Traits_>;
};

} // namespace ck_tile
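As an illustration, a bf16 instantiation might look as follows; MyShape and MyTraits are hypothetical placeholders for a concrete FusedMoeGemmShape and FusedMoeGemmTraits (the Silu activation is the default named in the comment above):

// Illustrative instantiation only; MyShape/MyTraits are assumed, not defined here.
using MyProblem = ck_tile::FusedMoeGemmPipelineProblem<
    ck_tile::bf16_t,             // ADataType: token activations
    ck_tile::bf16_t,             // GDataType: gate(+up) weights
    ck_tile::bf16_t,             // DDataType: down weights
    float,                       // AccDataType
    ck_tile::bf16_t,             // ODataType
    float,                       // AScaleDataType
    float,                       // GScaleDataType
    float,                       // DScaleDataType
    float,                       // YSmoothScaleDataType
    float,                       // TopkWeightDataType
    ck_tile::index_t,            // IndexDataType
    ck_tile::element_wise::Silu, // GateActivation
    MyShape,                     // BlockShape: a FusedMoeGemmShape
    MyTraits>;                   // a FusedMoeGemmTraits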
include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

enum class FusedMoeGemmWeightPermuteEnum
{
    // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
    // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
    no_permute          = 0,
    b_nr_kr_kw_nw_kv    = 1, // 0,1,3,4,2,5
    b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};

template <bool IsGateOnly_,
          bool UseSmoothQuant_,
          index_t OAtomic_, // 0: no atomic, 1: atomic on packed fp16/bf16, 2: atomic on fp32
          FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
              FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
          bool PadHiddenSize_       = false,
          bool PadIntermediateSize_ = false>
struct FusedMoeGemmTraits
{
    // Gate+Up or Gate only
    static constexpr bool IsGateOnly                           = IsGateOnly_;
    static constexpr bool UseSmoothQuant                       = UseSmoothQuant_;
    static constexpr index_t OAtomic                           = OAtomic_;
    static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
    static constexpr bool PadHiddenSize                        = PadHiddenSize_;
    static constexpr bool PadIntermediateSize                  = PadIntermediateSize_;
};

// Note: this needs to be a bit mask
enum class FusedMoeGemmPipelineSequencerEnum
{
    SLD_A = 1 << 0, // shared load a
    SLD_B = 1 << 1,
    GLD_A = 1 << 2, // global load a
    GLD_B = 1 << 3,
    SST_A = 1 << 4, // shared store a
    SST_B = 1 << 5,
    GST_O = 1 << 6, // global store out
};

} // namespace ck_tile
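Because the sequencer values are distinct bits, one slot can tag several operations at once, and consumers test membership with a bitwise AND, exactly as the pipelines above do with slot & GLD_B. A standalone check of that property (a plain-int mirror of the enum values above):

// Standalone bit-mask check mirroring FusedMoeGemmPipelineSequencerEnum's values.
constexpr int SLD_A = 1 << 0;
constexpr int GLD_A = 1 << 2;
constexpr int GLD_B = 1 << 3;

constexpr int slot = GLD_B | SLD_A; // one slot may tag several ops
static_assert((slot & GLD_B) != 0, "B global load tagged in this slot");
static_assert((slot & SLD_A) != 0, "A shared load tagged in this slot");
static_assert((slot & GLD_A) == 0, "A global load not tagged");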