Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
970be900
Commit
970be900
authored
Oct 21, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into letaoqin/update_layernorm
parents
009cce41
95e722a3
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
96 additions
and
48 deletions
+96
-48
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+16
-14
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+18
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+46
-21
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+2
-1
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+13
-2
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+1
-1
No files found.
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
970be900
...
...
@@ -191,7 +191,9 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
if (a.num_splits <= 16) {{
if (a.num_splits <= 8) {{
kernel_runner<3>::run(s, a);
}} else if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a);
}} else if (a.num_splits <= 32) {{
kernel_runner<5>::run(s, a);
...
...
@@ -239,7 +241,7 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}
/2
, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
...
...
@@ -551,10 +553,10 @@ class FmhaFwdSplitKVCombineKernel:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
'32'
:
FmhaFwdTileSize
(
128
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
128
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
128
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
32
,
32
,
16
,
-
1
),
'32'
:
FmhaFwdTileSize
(
32
,
64
,
16
,
32
,
32
,
32
,
2
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'64'
:
FmhaFwdTileSize
(
64
,
64
,
32
,
64
,
32
,
64
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'128'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
128
,
32
,
128
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
'256'
:
FmhaFwdTileSize
(
64
,
128
,
32
,
256
,
32
,
256
,
4
,
1
,
1
,
16
,
16
,
16
,
-
1
),
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
...
...
@@ -568,16 +570,16 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
def
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
'32'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
256
,
-
1
),
'32'
:
FmhaFwdSplitKVCombineTileSize
(
16
,
16
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
32
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
64
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
32
,
128
,
-
1
),
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
256
,
-
1
),
'64'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
32
,
-
1
),
'128'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
64
,
-
1
),
'256'
:
FmhaFwdSplitKVCombineTileSize
(
64
,
128
,
-
1
),
}
else
:
return
None
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
970be900
...
...
@@ -12,6 +12,16 @@ namespace detail {
template
<
index_t
N
>
struct
log2
;
template
<
>
struct
log2
<
4
>
:
std
::
integral_constant
<
index_t
,
2
>
{
};
template
<
>
struct
log2
<
8
>
:
std
::
integral_constant
<
index_t
,
3
>
{
};
template
<
>
struct
log2
<
16
>
:
std
::
integral_constant
<
index_t
,
4
>
{
...
...
@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if
constexpr
(
kHeadDimV
<=
32
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
128
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
256
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
}
}();
...
...
@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
970be900
...
...
@@ -10,11 +10,26 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
return
NPerThread
;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
return
16
/
sizeof
(
LSEDataType
);
return
GetVectorSizeForTile
<
Problem
::
kBlockSize
,
Problem
::
kMaxSplits
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
}
template
<
typename
Problem
>
...
...
@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPerThread
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
TotalWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
TotalWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
kNumWarps
*
MThreadsPerWarp
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
Total
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
MPerThread
*
kNum
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
Total
Warps
,
MThreadsPerWarp
>
,
tuple
<
sequence
<
MPerThread
,
kNum
Warps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
...
...
@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence
<
0
,
1
>>
{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding,
shape=
[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_lds_block_desc
;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding,
shape=
[kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
max
(
Problem
::
kMaxSplits
,
get_warp_size
())
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NThreads
=
get_warp_size
()
;
constexpr
index_t
NThreads
=
4
;
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
M
Threads
*
MPerThread
==
kMPerBlock
);
static_assert
(
M
Warps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M
Threads
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
>>
,
tuple
<
sequence
<
M
Warps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
970be900
...
...
@@ -115,7 +115,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
256
;
static
constexpr
index_t
kNumWarps
=
kM0_
/
(
get_warp_size
()
/
4
);
static
constexpr
index_t
kBlockSize
=
kNumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kHeadDimV
=
HeadDimV_
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
970be900
...
...
@@ -88,22 +88,33 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static_assert
(
WarpGemmM
==
16
||
WarpGemmM
==
32
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
// WarpGemmM == 16
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
// WarpGemmM == 16
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
static_assert
(
WarpGemmM
==
32
);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
...
...
include/ck_tile/ops/gemm.hpp
View file @
970be900
...
...
@@ -23,12 +23,12 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment