Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4885c38a
Commit
4885c38a
authored
Sep 03, 2024
by
aska-0096
Browse files
Merge branch 'transpose_opt' of
https://github.com/ROCm/composable_kernel
into rowwise_opt
parents
cbf14ee1
7c8e92fa
Changes
83
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1886 additions
and
343 deletions
+1886
-343
example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
..._convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
+98
-0
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+37
-3
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+13
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
+355
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+127
-61
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+604
-164
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+293
-30
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+16
-11
example/ck_tile/01_fmha/rotary.hpp
example/ck_tile/01_fmha/rotary.hpp
+84
-0
example/ck_tile/01_fmha/script/benchmark_bwd.sh
example/ck_tile/01_fmha/script/benchmark_bwd.sh
+2
-3
example/ck_tile/01_fmha/script/benchmark_fwd.sh
example/ck_tile/01_fmha/script/benchmark_fwd.sh
+2
-3
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+2
-3
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
+94
-41
example/ck_tile/01_fmha/utils.hpp
example/ck_tile/01_fmha/utils.hpp
+97
-13
include/ck/ck.hpp
include/ck/ck.hpp
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
...eration/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+2
-2
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
...operation/gpu/element/combined_element_wise_operation.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+11
-4
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+9
-0
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+34
-0
No files found.
example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// Parses command-line options and runs a grouped convolution forward example
// for 1, 2 or 3 spatial dimensions.
//
// Usage:
//   <exe>                                     : run with the default problem
//   <exe> verify init time                    : override verification/init/timing
//   <exe> verify init time ndim <conv args..> : also override the conv problem
//
// Returns true on success (also when the requested ndim has no specialization).
bool run_convnd_fwd_example(int argc, char* argv[])
{
    print_helper_msg();

    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // default problem: 2D, 1 group, N=128, K=256, C=192, 3x3 filter,
    // 71x71 input, stride 2, dilation 1, pad 1/1
    ck::utils::conv::ConvParam conv_param{
        2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};

    if(argc == 1)
    {
        // use default
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else
    {
        // fix: the original read argv[1..4] unconditionally here; with argc==2
        // or argc==3 that indexes past the argument vector (undefined behavior)
        if(argc < 5)
        {
            return false;
        }

        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        const ck::index_t num_dim_spatial = std::stoi(argv[4]);

        // the remaining arguments (from index 5 on) describe the conv problem
        conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
    }

    // instantiate in and wei element ops, will
    // instantiate out_element_op below for every iteration
    const auto in_element_op  = InElementOp{};
    const auto wei_element_op = WeiElementOp{};

    // generic runner, instantiated per (ndim, layout) combination below
    const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
        constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;

        using InLayout  = decltype(in_layout);
        using WeiLayout = decltype(wei_layout);
        using OutLayout = decltype(out_layout);

        const auto in_g_n_c_wis_desc =
            ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
                conv_param);

        const auto wei_g_k_c_xs_desc =
            ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
                conv_param);

        const auto out_g_n_k_wos_desc =
            ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
                conv_param);

        return run_grouped_conv_fwd<ndim_spatial_value,
                                    InDataType,
                                    WeiDataType,
                                    ConvOutDataType,
                                    OutDataType,
                                    InElementOp,
                                    WeiElementOp,
                                    OutElementOp,
                                    DeviceGroupedConvNDFwdInstance<ndim_spatial_value,
                                                                   InLayout,
                                                                   WeiLayout,
                                                                   OutLayout>>(do_verification,
                                                                               init_method,
                                                                               time_kernel,
                                                                               conv_param,
                                                                               in_g_n_c_wis_desc,
                                                                               wei_g_k_c_xs_desc,
                                                                               out_g_n_k_wos_desc,
                                                                               in_element_op,
                                                                               wei_element_op);
    };

    namespace ctc = ck::tensor_layout::convolution;

    if(conv_param.num_dim_spatial_ == 1)
    {
        return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
    }
    else if(conv_param.num_dim_spatial_ == 2)
    {
        return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
    }
    else if(conv_param.num_dim_spatial_ == 3)
    {
        return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
    }

    return true;
}
example/ck_tile/01_fmha/CMakeLists.txt
View file @
4885c38a
# generate a list of kernels, but not actually emit files at config stage
# validate user-specified fmha_fwd API list
set
(
FMHA_FWD_KNOWN_APIS
"fwd;fwd_splitkv;fwd_appendkv"
)
set
(
FMHA_FWD_ENABLE_APIS
"fwd"
CACHE STRING
"semicolon-separated list of APIs to generate (
${
FMHA_FWD_KNOWN_APIS
}
) & link, or
\"
all
\"
."
)
if
(
FMHA_FWD_ENABLE_APIS STREQUAL
"all"
)
set
(
FMHA_FWD_ENABLE_APIS
${
FMHA_FWD_KNOWN_APIS
}
)
endif
()
foreach
(
api
${
FMHA_FWD_ENABLE_APIS
}
)
if
(
NOT
"
${
api
}
"
IN_LIST FMHA_FWD_KNOWN_APIS
)
message
(
FATAL_ERROR
"
${
api
}
isn't a known api:
${
FMHA_FWD_KNOWN_APIS
}
."
)
endif
()
endforeach
()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if
(
NOT
"fwd"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND FMHA_FWD_ENABLE_APIS
"fwd"
)
endif
()
string
(
REPLACE
";"
","
FMHA_FWD_APIS
"
${
FMHA_FWD_ENABLE_APIS
}
"
)
# generate a list of kernels, but not actually emit files at config sta
execute_process
(
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
fwd,fwd_splitkv
--list_blobs
${
CMAKE_CURRENT_BINARY_DIR
}
/fwd_blob_list.txt
--api
${
FMHA_FWD_APIS
}
--list_blobs
${
CMAKE_CURRENT_BINARY_DIR
}
/fwd_blob_list.txt
)
execute_process
(
...
...
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command
(
OUTPUT
${
FMHA_FWD_GEN_BLOBS
}
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
fwd,fwd_splitkv
--output_dir
${
CMAKE_CURRENT_BINARY_DIR
}
--api
${
FMHA_FWD_APIS
}
--output_dir
${
CMAKE_CURRENT_BINARY_DIR
}
)
add_custom_command
(
...
...
@@ -60,6 +80,20 @@ else()
endif
()
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero
)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
if
(
"fwd_splitkv"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1
)
else
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0
)
endif
()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example
if
(
"fwd_appendkv"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1
)
else
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0
)
endif
()
# Allow comparing floating points directly in order to check sentinel values
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
...
...
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
4885c38a
...
...
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
"dropout_wg16_storerandval"
:
"t.has_dropout == true && t.is_store_randval == true"
,
}
ROPE_MAP
=
{
"no"
:
"ck_tile::RotaryEmbeddingEnum::NONE"
,
"inter"
:
"ck_tile::RotaryEmbeddingEnum::INTERLEAVED"
,
"half"
:
"ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}
ROPE_CHECK_MAP
=
{
"no"
:
"rope_enum::none"
,
"inter"
:
"rope_enum::interleaved"
,
"half"
:
"rope_enum::half_rotated"
}
MODE_MAP
=
{
"batch"
:
"false"
,
"group"
:
"true"
...
...
@@ -105,4 +117,4 @@ PIPELINE_ENUM_MAP = {
BOOL_MAP
=
{
"t"
:
"true"
,
"f"
:
"false"
}
\ No newline at end of file
}
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
0 → 100644
View file @
4885c38a
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import
copy
from
dataclasses
import
dataclass
import
fnmatch
import
itertools
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Tuple
from
codegen.cmake_config
import
*
from
codegen.cpp_symbol_map
import
*
from
codegen.ops.fmha_fwd
import
(
FmhaFwdApiTrait
,
DTYPE_BITS
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_HDIM_CASE
,
)
FMHA_FWD_APPENDKV_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
{F_bs},
{F_bsk},
{F_bd},
{F_bdv},
{F_vlayout},
{F_rope},
{F_pagedkv},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME
=
"fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API
=
"""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
    """Dispatch descriptor for one generated appendkv kernel.

    Kept in sync with the C++ fmha_fwd_appendkv_traits_<> template so the
    generated API source can emit matching fallback/dispatch conditions.
    """
    hdim    : str
    dtype   : str   # data type
    bs      : int   # tile size along q seqlen
    bsk     : int   # tile size along k seqlen
    bd      : int   # tile size along qk gemm unroll
    bdv     : int   # tile size along kv gemm unroll
    vlayout : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str
    rope    : str   # key from ROPE_MAP
    pagedkv : str

    @property
    def name(self) -> str:
        # unique id: every trait field joined with '-'
        fields = [self.hdim, self.dtype, self.bs, self.bsk, self.bd, self.bdv,
                  self.vlayout, self.spad, self.skpad, self.dpad, self.dvpad,
                  self.rope, self.pagedkv]
        return '-'.join(str(field) for field in fields)

    @property
    def scheck(self) -> str:
        # C++ condition guarding seqlen_q divisibility by the q-seqlen tile
        if self.spad == 't':
            return f'true /*a.seqlen_q % {self.bs} != 0*/'
        return f'a.seqlen_q % {self.bs} == 0'

    @property
    def skcheck(self) -> str:
        # we do not check all the values in a.seqlen_k_ptr
        return 'true'

    @property
    def dcheck(self) -> str:
        # C++ condition guarding hdim_q divisibility by the qk-gemm tile
        if self.dpad == 't':
            return f'true /*a.hdim_q % {self.bd} != 0*/'
        return f'a.hdim_q % {self.bd} == 0'

    @property
    def dvcheck(self) -> str:
        # C++ condition guarding hdim_v divisibility by the kv-gemm tile
        if self.dvpad == 't':
            return f'true /*a.hdim_v % {self.bdv} != 0*/'
        return f'a.hdim_v % {self.bdv} == 0'
@dataclass
class FmhaFwdAppendKVPipeline:
    F_vlayout : str  # row/col
    F_spad    : str  # true/false
    F_skpad   : str  #
    F_dpad    : str  #
    F_dvpad   : str  #
    F_rope    : str  # key from ROPE_MAP
    F_pagedkv : str  # t/f

    @property
    def name(self) -> str:
        """Short tag encoding v-layout, padding flags, rope mode and paged-kv."""
        # collect enabled padding markers, e.g. spad+skpad -> 'ssk'
        pads = ''.join(tag for flag, tag in ((self.F_spad, 's'),
                                             (self.F_skpad, 'sk'),
                                             (self.F_dpad, 'd'),
                                             (self.F_dvpad, 'dv')) if flag == 't')
        parts = [f'v{self.F_vlayout[0]}']
        if pads:
            parts.append(f'p{pads}')
        if self.F_rope != 'no':
            parts.append(self.F_rope)
        if self.F_pagedkv == 't':
            parts.append('pagedkv')
        return '_'.join(parts)
class FmhaFwdAppendKVApiPool:
    """Collects registered kernel traits and renders the C++ dispatch API source."""
    def __init__(self, mask_impl):
        # pool[dtype][hdim] -> list of traits, in registration order
        self.pool = dict()
        self.mask_impl = mask_impl

    # fix: parameter was annotated FmhaFwdApiTrait, but every caller passes a
    # FmhaFwdAppendKVApiTrait (see FmhaFwdAppendKVKernel.api_trait)
    def register_traits(self, trait : FmhaFwdAppendKVApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        """Full C++ source of fmha_fwd_appendkv(): nested dispatch per dtype/hdim/trait."""
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    # first branch opens with 'if', the rest chain with 'else if'
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
                        F_if=if_k,
                        F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_scheck=trait.scheck,
                        F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck,
                        F_dvcheck=trait.dvcheck,
                        F_rope_check=ROPE_CHECK_MAP[trait.rope],
                        F_pagedkv=BOOL_MAP[trait.pagedkv],
                        F_spad=BOOL_MAP[trait.spad],
                        F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad],
                        F_dvpad=BOOL_MAP[trait.dvpad],
                        F_rope=ROPE_MAP[trait.rope],
                        F_bs=trait.bs,
                        F_bsk=trait.bsk,
                        F_bd=trait.bd,
                        F_bdv=trait.bdv,
                        F_hdim=hdim,
                        F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
                    F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
                F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch=per_dtypes)
@dataclass
class FmhaFwdAppendKVTileSize:
    F_bs        : int  # tile size along q seqlen
    F_bsk       : int  # tile size along k seqlen
    F_bd        : int  # tile size along qk gemm unroll
    F_bdv       : int  # tile size along kv gemm unroll
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy,
                       # other value will overwrite occupancy

    @property
    def name(self) -> str:
        """Tile tag like 'b64x64x128x128', plus '_o<n>' when occupancy is forced."""
        tag = f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}"
        if self.F_occupancy != -1:
            tag += f"_o{self.F_occupancy}"
        return tag
@dataclass
class FmhaFwdAppendKVKernel:
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaFwdAppendKVTileSize
    F_pipeline : FmhaFwdAppendKVPipeline
    mask_impl  : str

    @property
    def template(self) -> str:
        """C++ source for this kernel instance (common header + filled-in body)."""
        # fix: dropped the original's unused local 'kernel_body = str()'
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_APPENDKV_KERNEL_BODY.format(
                F_idx       = self.F_idx,
                F_hdim      = self.F_hdim,
                F_dtype     = DTYPE_MAP[self.F_dtype],
                F_bs        = self.F_tile.F_bs,
                F_bsk       = self.F_tile.F_bsk,
                F_bd        = self.F_tile.F_bd,
                F_bdv       = self.F_tile.F_bdv,
                F_vlayout   = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad      = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad     = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad      = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad     = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_rope      = ROPE_MAP[self.F_pipeline.F_rope],
                F_pagedkv   = BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_occupancy = self.F_tile.F_occupancy)

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
        """Project this kernel's parameters into a dispatch-table trait."""
        return FmhaFwdAppendKVApiTrait(
            hdim    = str(self.F_hdim),
            dtype   = self.F_dtype,
            bs      = self.F_tile.F_bs,
            bsk     = self.F_tile.F_bsk,
            bd      = self.F_tile.F_bd,
            bdv     = self.F_tile.F_bdv,
            vlayout = self.F_pipeline.F_vlayout,
            spad    = self.F_pipeline.F_spad,
            skpad   = self.F_pipeline.F_skpad,
            dpad    = self.F_pipeline.F_dpad,
            dvpad   = self.F_pipeline.F_dvpad,
            rope    = self.F_pipeline.F_rope,
            pagedkv = self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    """Return the hdim(str) -> tile-size table supported for dtype, or None."""
    if dtype in ('fp16', 'bf16'):
        return {
            '32'  : FmhaFwdAppendKVTileSize(64, 64,  32,  32, -1),
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    if dtype in ('fp8', 'bf8'):
        return {
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
        }
    return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    """Enumerate all appendkv kernels to generate plus the API pool describing them.

    kernel_filter: optional fnmatch pattern applied to kernel names.
    receipt:       generation recipe id; 2 restricts to fp16/bf16 row-major v.
    mask_impl:     forwarded to the kernels/pool (kept for API parity).
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        # fix: removed unused local 'squant' (was computed but never read)
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
            # applying rotary embedding, so I just use 't' in inter/half pipelines
            for vlayout in ['row', 'col']:
                for pagedkv in ["t", "f"]:
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
        elif dtype in ['fp8', 'bf8']:
            # rope/paged-kv is not supported
            pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
        # fix: use identity comparison with None (PEP 8), not ==/!=
        if d is None:
            continue
        for hdim_str in d.keys():
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                k = FmhaFwdAppendKVKernel(F_idx=0,
                                          F_hdim=hdim,
                                          F_dtype=dtype,
                                          F_tile=tile,
                                          F_pipeline=pipeline,
                                          mask_impl=mask_impl)
                if kernel_filter is not None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
def write_single_kernel(kernel : FmhaFwdAppendKVKernel, autogen_dir : Path) -> None:
    """Emit one generated kernel's C++ source file into autogen_dir."""
    target = autogen_dir / kernel.filename
    target.write_text(kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir : Path) -> None:
    """Emit the fmha_fwd_appendkv() dispatch translation unit into autogen_dir."""
    target = autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME
    target.write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Generate every selected appendkv kernel source plus the dispatch API file."""
    api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
    for k in kernels:
        write_single_kernel(k, output_dir)
    write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    """Append the paths of all files that would be generated to file_path, one per line."""
    _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
    gen_root = file_path.parent / GEN_DIR
    with file_path.open('a') as f:
        for k in kernels:
            f.write(str(gen_root / k.filename) + "\n")
        f.write(str(gen_root / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
4885c38a
...
...
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
)
DTYPE_BITS
=
{
"fp32"
:
32
,
"fp16"
:
16
,
"bf16"
:
16
,
"fp8"
:
8
,
"bf8"
:
8
}
FMHA_FWD_SPLITKV_PIPELINE_MAP
=
{
"qr"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS"
,
"qr_async"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync"
,
...
...
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_bias},
false,
{F_lse},
{F_dropout},
{F_squant},
{F_pagedkv},
kHasUnevenSplits,
{F_occupancy}>;
...
...
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
...
...
@@ -86,7 +93,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
...
...
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
}};
}}
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
if constexpr({F_mode} == false) {{ // batch mode
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
}} else {{
kernel_runner<true>::run(s, a);
...
...
@@ -160,7 +172,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
...
...
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
#include <iostream>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a);
...
...
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
if(s.log_level_ > 0)
std::cout
...
...
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
);
}}
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_
splitkv_
traits t, fmha_fwd_
splitkv_
args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse})
&& (t.has_dropout == {F_dropout})
&& (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_
dropou
t}, {F_
squant
}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_
splitkv_
traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_
squan
t}, {F_
pagedkv
}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
"""
@
dataclass
class
FmhaFwdSplitKVApiTrait
:
pipeline_tag
:
str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim
:
str
dtype
:
str
# data type
mode
:
str
# value from MODE_MAP
bm0
:
int
# tile size along q seqlen (block size)
bn0
:
int
# tile size along qk seqlen
bk0
:
int
# tile size along qk gemm unroll
bn1
:
int
# tile size along v head_dim
bk1
:
int
# tile size along kv gemm unroll
bk0blen
:
int
vlayout
:
str
mask
:
str
bias
:
str
#
lse
:
str
#
squant
:
str
#
spad
:
str
skpad
:
str
dpad
:
str
dvpad
:
str
pagedkv
:
str
@
property
def
name
(
self
)
->
str
:
return
f
'
{
self
.
hdim
}
-
{
self
.
dtype
}
-
{
self
.
mode
}
-
{
self
.
bm0
}
-
{
self
.
bn0
}
-
{
self
.
bk0
}
-
{
self
.
bn0
}
-
{
self
.
bk1
}
-
{
self
.
bk0blen
}
-'
+
\
f
'
{
self
.
vlayout
}
-
{
self
.
mask
}
-
{
self
.
bias
}
-
{
self
.
lse
}
-
{
self
.
squant
}
-
{
self
.
spad
}
-
{
self
.
skpad
}
-
{
self
.
dpad
}
-'
+
\
f
'
{
self
.
dvpad
}
-
{
self
.
pagedkv
}
'
@
property
def
scheck
(
self
)
->
str
:
if
self
.
mode
==
'group'
:
return
'true/*group mode spad always true*/'
# group mode only generate spad/skpad == true
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
spad
==
't'
:
return
'true'
# always support
else
:
return
'true'
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
spad
==
't'
:
return
f
'true /*a.seqlen_q %
{
self
.
bm0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0'
else
:
assert
False
@
property
def
skcheck
(
self
)
->
str
:
if
self
.
mode
==
'group'
:
return
'true/*group mode skpad always true*/'
# group mode only generate spad/skpad == true
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
skpad
==
't'
:
return
f
'a.seqlen_k == 0 || a.seqlen_k %
{
self
.
bn0
}
!= 0'
else
:
return
f
'a.seqlen_k != 0 && a.seqlen_k %
{
self
.
bn0
}
== 0'
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_fp8'
]:
if
self
.
skpad
==
't'
:
return
f
'true /*a.seqlen_k %
{
self
.
bn0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_k %
{
self
.
bn0
}
== 0'
else
:
assert
False
@
property
def
dcheck
(
self
)
->
str
:
if
self
.
pipeline_tag
==
'qr_async'
:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dpad
==
't'
:
return
f
'a.hdim_q %
{
vec
}
== 0'
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
dpad
==
't'
:
return
f
'true /*a.hdim_q %
{
self
.
bk0blen
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_q %
{
self
.
bk0blen
}
== 0'
else
:
assert
False
@
property
def
dvcheck
(
self
)
->
str
:
if
self
.
pipeline_tag
==
'qr_async'
:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dvpad
==
't'
:
return
f
'a.hdim_v %
{
vec
}
== 0'
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
dvpad
==
't'
:
return
f
'true /*a.hdim_v %
{
self
.
bk0blen
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_v %
{
self
.
bk0blen
}
== 0'
else
:
assert
False
@
dataclass
class
FmhaFwdSplitKVPipeline
:
tag
:
str
...
...
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
F_dvpad
:
str
#
F_bias
:
str
# true/false
F_lse
:
str
#
F_dropout
:
str
#
F_squant
:
str
#
F_pagedkv
:
str
# t/f
F_mask
:
str
# value from MASK_MAP
@
property
...
...
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
else
:
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_dropout
==
't'
:
n
+=
'_dropout'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
if
self
.
F_pagedkv
==
't'
:
n
+=
'_pagedkv'
return
n
@
dataclass
...
...
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
self
.
pool
=
dict
()
self
.
mask_impl
=
mask_impl
def
register_traits
(
self
,
trait
:
FmhaFwdApiTrait
)
->
None
:
def
register_traits
(
self
,
trait
:
FmhaFwd
SplitKV
ApiTrait
)
->
None
:
# TODO: do we need to check duplication?
if
trait
.
dtype
not
in
self
.
pool
.
keys
():
self
.
pool
[
trait
.
dtype
]
=
dict
()
...
...
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
inners
=
inners
+
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_
dropou
t
=
BOOL_MAP
[
trait
.
dropout
]
,
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_
squan
t
=
BOOL_MAP
[
trait
.
squant
],
F_pagedkv
=
BOOL_MAP
[
trait
.
pagedkv
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
...
...
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_dropout
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_pagedkv
=
BOOL_MAP
[
self
.
F_pipeline
.
F_pagedkv
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_pipeline
.
F_mask
],
...
...
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
def
filename
(
self
)
->
str
:
return
self
.
name
+
".cpp"
def
api_trait
(
self
)
->
FmhaFwdApiTrait
:
return
FmhaFwdApiTrait
(
def
api_trait
(
self
)
->
FmhaFwd
SplitKV
ApiTrait
:
return
FmhaFwd
SplitKV
ApiTrait
(
pipeline_tag
=
self
.
F_pipeline
.
tag
,
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
...
...
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
mask
=
self
.
F_pipeline
.
F_mask
,
bias
=
self
.
F_pipeline
.
F_bias
,
lse
=
self
.
F_pipeline
.
F_lse
,
dropout
=
self
.
F_pipeline
.
F_dropout
,
squant
=
self
.
F_pipeline
.
F_squant
,
pagedkv
=
self
.
F_pipeline
.
F_pagedkv
,
spad
=
self
.
F_pipeline
.
F_spad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
...
...
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
def
filename
(
self
)
->
str
:
return
self
.
name
+
".cpp"
def
api_trait
(
self
)
->
FmhaFwdApiTrait
:
return
FmhaFwdApiTrait
(
pipeline_tag
=
self
.
F_pipeline
.
tag
,
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
mode
=
self
.
F_mode
,
bm0
=
self
.
F_tile
.
F_bm0
,
bn0
=
self
.
F_tile
.
F_bn0
,
bk0
=
self
.
F_tile
.
F_bk0
,
bn1
=
self
.
F_tile
.
F_bn1
,
bk1
=
self
.
F_tile
.
F_bk1
,
bk0blen
=
self
.
F_tile
.
F_bk0blen
,
vlayout
=
self
.
F_pipeline
.
F_vlayout
,
mask
=
self
.
F_pipeline
.
F_mask
,
bias
=
self
.
F_pipeline
.
F_bias
,
lse
=
self
.
F_pipeline
.
F_lse
,
dropout
=
self
.
F_pipeline
.
F_dropout
,
squant
=
self
.
F_pipeline
.
F_squant
,
spad
=
self
.
F_pipeline
.
F_spad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
dvpad
=
self
.
F_pipeline
.
F_dvpad
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
...
...
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
# splitkv kernel donot support dropout
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"f"
]):
if
hdim
==
256
:
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
# TODO: use async pipeline when compiler is more stable
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]
:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
else
:
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
if
receipt
==
1
:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# no need lse/
dropout
kernels
# no need lse/
paged-kv
kernels
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'f'
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
squant
,
'f'
,
mask
))
else
:
assert
False
return
pipelines
...
...
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if
pipeline
.
F_pagedkv
==
't'
:
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k
=
Kernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
4885c38a
...
...
@@ -4,6 +4,7 @@
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
#include <array>
...
...
@@ -16,6 +17,10 @@
#include <utility>
#include <vector>
#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
#endif
template
<
typename
T
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
std
::
vector
<
T
>&
v
)
{
...
...
@@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[])
"seqlen_q. if group-mode, means the average value of seqlen_q
\n
"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
\n
"
"also with
\"
-s=s0,s1,s2...
\"
comma seperated int to set per batch seqlen(group-mode)"
)
.
insert
(
"s_k"
,
"-1"
,
"seqlen_k, -1 means equal to s"
)
.
insert
(
"s_k"
,
"-1"
,
"seqlen_k (including new key/value), -1 means equal to s"
)
.
insert
(
"s_knew"
,
"0"
,
"seqlen_k for new key/value, 0 means not to use this at all; "
"-1 to choose s_knew in [1, s] randomly."
)
.
insert
(
"s_kpad"
,
"-1"
,
"seqlen_k stride between 2 tokens, currently used in group-mode only
\n
"
...
...
@@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[])
.
insert
(
"drop_seed"
,
"1"
,
"seed for random number generator"
)
.
insert
(
"drop_offset"
,
"0"
,
"offset for random number generator"
)
.
insert
(
"timer"
,
"gpu"
,
"gpu:gpu timer, cpu:cpu timer"
)
.
insert
(
"rotary_dim"
,
"0"
,
"RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all"
)
.
insert
(
"rotary_interleaved"
,
"1"
,
"whether to apply interleaved RoPE"
)
.
insert
(
"num_splits"
,
"1"
,
"# of splits for key/value. 0 to determine actual number by heuristic"
)
.
insert
(
"page_block_size"
,
"0"
,
"paged-kvcache block size. 0 means not use paged-kvcahe"
)
.
insert
(
"cache_batch_idx"
,
"0"
,
"whether to use index map to the kvcache"
)
.
insert
(
"warmup"
,
"5"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"repeat"
,
"20"
,
"number of iterations to benchmark the kernel"
);
...
...
@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
return
num_splits
;
}
float
fmha_fwd_dispatch
(
fmha_fwd_traits
traits
,
fmha_fwd_args
args
,
const
ck_tile
::
stream_config
&
config
)
{
if
(
1
<
args
.
num_splits
)
{
return
fmha_fwd_splitkv
(
traits
,
args
,
config
);
}
else
{
return
fmha_fwd
(
traits
,
args
,
config
);
}
}
template
<
typename
DataType
>
bool
run
(
const
ck_tile
::
ArgParser
&
arg_parser
)
{
...
...
@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
return
false
;
}
auto
[
seqlen_qs
,
seqlen_ks
,
seqlen_kpads
]
=
decode_seqlen
(
mode
,
batch
,
arg_parser
.
get_str
(
"s"
),
arg_parser
.
get_str
(
"s_k"
),
arg_parser
.
get_str
(
"s_kpad"
));
std
::
optional
<
uint32_t
>
seed
=
arg_parser
.
get_uint32
(
"seed"
);
if
(
*
seed
==
0
)
{
seed
.
reset
();
}
ck_tile
::
index_t
hdim_q
=
arg_parser
.
get_int
(
"d"
);
ck_tile
::
index_t
hdim_v
=
arg_parser
.
get_int
(
"d_v"
);
if
(
hdim_v
<
0
)
hdim_v
=
hdim_q
;
ck_tile
::
index_t
seqlen_knew
=
arg_parser
.
get_int
(
"s_knew"
);
#if !CK_TILE_FMHA_FWD_APPENDKV_API
if
(
seqlen_knew
!=
0
)
{
std
::
cerr
<<
"kvcache is not supported. ignoring the 's_knew' option"
<<
std
::
endl
;
seqlen_knew
=
0
;
}
#endif
if
(
seqlen_knew
<
0
)
{
seqlen_knew
=
randint
<
ck_tile
::
index_t
>
(
1
,
arg_parser
.
get_int
(
"s"
),
seed
);
}
ck_tile
::
index_t
rotary_dim
=
arg_parser
.
get_int
(
"rotary_dim"
);
if
constexpr
(
!
(
std
::
is_same_v
<
DataType
,
ck_tile
::
fp16_t
>
||
std
::
is_same_v
<
DataType
,
ck_tile
::
bf16_t
>
))
{
if
(
0
<
rotary_dim
)
{
std
::
cerr
<<
"rotary embedding is only available for data type=fp16|bf16"
<<
std
::
endl
;
return
false
;
}
}
#if !CK_TILE_FMHA_FWD_APPENDKV_API
else
if
(
0
<
rotary_dim
)
{
std
::
cerr
<<
"rotary embedding is not supported. ignoring the 'rotary_dim' option"
<<
std
::
endl
;
rotary_dim
=
0
;
}
#endif
if
(
!
(
rotary_dim
<=
hdim_q
))
{
std
::
cerr
<<
"rotary_dim should be less than or equal to head dim for q"
<<
std
::
endl
;
return
false
;
}
else
if
(
!
(
rotary_dim
%
16
==
0
))
{
std
::
cerr
<<
"only rotary dimensions divisible by 16 are currently supported"
<<
std
::
endl
;
return
false
;
}
ck_tile
::
index_t
page_block_size
=
arg_parser
.
get_int
(
"page_block_size"
);
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
page_block_size
)
{
std
::
cerr
<<
"paged-kvcache is not supported. ignoring the 'page_block_size' option"
<<
std
::
endl
;
page_block_size
=
0
;
}
#endif
if
(
!
(
page_block_size
%
128
==
0
))
{
std
::
cerr
<<
"only paged-kvcache block size divisible by 128 are currently supported"
<<
std
::
endl
;
return
false
;
}
bool
use_cache_batch_idx
=
arg_parser
.
get_bool
(
"cache_batch_idx"
);
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
if
(
use_cache_batch_idx
)
{
std
::
cerr
<<
"split-kv is not supported. ignoring the 'cache_batch_idx' option"
<<
std
::
endl
;
use_cache_batch_idx
=
false
;
}
#endif
if
(
0
<
page_block_size
&&
use_cache_batch_idx
)
{
std
::
cerr
<<
"paged-kvcache does not support cache_batch_idx. ignoring the "
"'cache_batch_idx' option"
<<
std
::
endl
;
use_cache_batch_idx
=
false
;
}
// the input tensor layout for kvcache is same as batch mode
const
bool
need_append_kvcache
=
(
0
<
seqlen_knew
||
0
<
rotary_dim
);
const
bool
use_kvcache
=
(
need_append_kvcache
||
use_cache_batch_idx
||
0
<
page_block_size
);
if
(
use_kvcache
&&
mode
!=
mode_enum
::
batch
)
{
std
::
cerr
<<
"kvcache enabled. ignoring the 'mode' option"
<<
std
::
endl
;
mode
=
mode_enum
::
batch
;
}
auto
[
seqlen_qs
,
seqlen_ks
,
seqlen_kpads
]
=
decode_seqlen
(
mode
,
batch
,
arg_parser
.
get_str
(
"s"
),
arg_parser
.
get_str
(
"s_k"
),
arg_parser
.
get_str
(
"s_kpad"
),
/*seqlen_k_min=*/
0
<
seqlen_knew
?
seqlen_knew
:
0
,
use_kvcache
);
// compute kvcache seqlen_k (before appending knew/vnew)
auto
cache_seqlen_ks
=
seqlen_ks
;
std
::
transform
(
cache_seqlen_ks
.
begin
(),
cache_seqlen_ks
.
end
(),
cache_seqlen_ks
.
begin
(),
[
&
](
auto
seqlen_k
)
{
return
seqlen_k
-
seqlen_knew
;
});
#if 0
// clang-format off
...
...
@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
#endif
ck_tile
::
index_t
hdim_q
=
arg_parser
.
get_int
(
"d"
);
ck_tile
::
index_t
hdim_v
=
arg_parser
.
get_int
(
"d_v"
);
if
(
hdim_v
<
0
)
hdim_v
=
hdim_q
;
bool
i_perm
=
arg_parser
.
get_bool
(
"iperm"
);
// if true, will be batch * nhead * seqlen * hdim
bool
o_perm
=
arg_parser
.
get_bool
(
"operm"
);
// if false, will be batch * seqlen * nhead * hdim
...
...
@@ -356,14 +454,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
s_randval
=
true
;
}
std
::
string
init_method
=
arg_parser
.
get_str
(
"init"
);
std
::
optional
<
uint32_t
>
seed
=
arg_parser
.
get_uint32
(
"seed"
);
if
(
*
seed
==
0
)
std
::
string
init_method
=
arg_parser
.
get_str
(
"init"
);
const
bool
is_rotary_interleaved
=
arg_parser
.
get_bool
(
"rotary_interleaved"
);
ck_tile
::
index_t
num_splits
=
arg_parser
.
get_int
(
"num_splits"
);
#if !CK_TILE_FMHA_FWD_SPLITKV_API
if
(
num_splits
!=
1
)
{
seed
.
reset
();
std
::
cerr
<<
"split-kv is not supported. ignoring the 'num_splits' option"
<<
std
::
endl
;
num_splits
=
1
;
}
int
num_splits
=
arg_parser
.
get_int
(
"num_splits"
);
#endif
int
stream_warmup
=
arg_parser
.
get_int
(
"warmup"
);
int
stream_repeat
=
arg_parser
.
get_int
(
"repeat"
);
...
...
@@ -425,6 +527,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
const
ck_tile
::
index_t
max_num_page_blocks
=
(
0
<
page_block_size
?
batch
*
std
::
max
(
1
,
ck_tile
::
integer_divide_ceil
(
max_seqlen_k
,
page_block_size
))
:
0
);
// legalize num_splits according to other options
if
(
num_splits
<
1
)
{
...
...
@@ -436,6 +543,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
std
::
cerr
<<
"num_splits greater than 128 is not supported"
<<
std
::
endl
;
return
false
;
}
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
p_drop
&&
(
1
<
num_splits
||
use_kvcache
))
{
std
::
cerr
<<
"dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
<<
std
::
endl
;
p_drop
=
0.0
f
;
}
#endif
auto
get_lengths
=
[
&
](
bool
permute
,
ck_tile
::
index_t
b
/*batch*/
,
...
...
@@ -462,11 +577,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
QDataType
>
q_host
(
get_lengths
(
i_perm
,
shape_batch
,
nhead
,
shape_seqlen_q
,
hdim_q
));
ck_tile
::
HostTensor
<
KDataType
>
k_host
(
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
0
<
page_block_size
?
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_q
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_q
));
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile
::
HostTensor
<
KDataType
>
knew_host
(
0
<
seqlen_knew
?
get_lengths
(
i_perm
,
batch
,
nhead_k
,
seqlen_knew
,
hdim_q
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
VDataType
>
v_host
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
));
0
<
page_block_size
?
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
page_block_size
,
hdim_v
)
:
get_lengths
(
i_perm
,
max_num_page_blocks
,
nhead_k
,
hdim_v
,
page_block_size
))
:
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
shape_seqlen_k
,
hdim_v
)
:
get_lengths
(
i_perm
,
shape_batch
,
nhead_k
,
hdim_v
,
shape_seqlen_k
)));
ck_tile
::
HostTensor
<
VDataType
>
vnew_host
(
0
<
seqlen_knew
?
(
is_v_rowmajor
?
get_lengths
(
i_perm
,
batch
,
nhead_k
,
seqlen_knew
,
hdim_v
)
:
get_lengths
(
i_perm
,
batch
,
nhead_k
,
hdim_v
,
seqlen_knew
))
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
}
/* dummy shape for simplifying code */
);
ck_tile
::
HostTensor
<
BiasDataType
>
bias_host
(
bias
.
type
==
bias_enum
::
elementwise_bias
?
get_lengths
(
i_perm
,
1
,
1
,
shape_seqlen_q
,
shape_seqlen_k
)
...
...
@@ -478,12 +608,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
:
std
::
array
<
ck_tile
::
index_t
,
2
>
{
batch
,
nhead
})
:
std
::
array
<
ck_tile
::
index_t
,
2
>
{
1
,
1
});
auto
[
rotary_cos_host
,
rotary_sin_host
]
=
generate_rotary_cos_sin
<
KDataType
>
(
std
::
max
(
shape_seqlen_q
,
shape_seqlen_k
),
rotary_dim
,
seed
);
ck_tile
::
HostTensor
<
LSEDataType
>
lse_acc_host
(
1
<
num_splits
1
<
num_splits
||
use_kvcache
?
std
::
array
<
ck_tile
::
index_t
,
4
>
{
num_splits
,
shape_batch
,
nhead
,
shape_seqlen_q
}
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
OaccDataType
>
o_acc_host
(
1
<
num_splits
1
<
num_splits
||
use_kvcache
?
std
::
array
<
ck_tile
::
index_t
,
5
>
{
num_splits
,
batch
,
nhead
,
max_seqlen_q
,
hdim_v
}
:
std
::
array
<
ck_tile
::
index_t
,
5
>
{
1
,
1
,
1
,
1
,
1
});
...
...
@@ -500,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
p_drop
>
0
?
get_lengths
(
true
,
shape_batch
,
nhead
,
shape_seqlen_q
,
max_seqlen_k
)
:
std
::
array
<
ck_tile
::
index_t
,
4
>
{
1
,
1
,
1
,
1
});
ck_tile
::
HostTensor
<
int32_t
>
block_table_host
(
0
<
page_block_size
?
std
::
array
<
ck_tile
::
index_t
,
2
>
{
batch
,
max_num_page_blocks
/
batch
}
:
std
::
array
<
ck_tile
::
index_t
,
2
>
{
1
,
1
});
ck_tile
::
HostTensor
<
int32_t
>
cache_batch_idx_host
(
use_cache_batch_idx
?
std
::
array
<
ck_tile
::
index_t
,
1
>
{
batch
}
:
std
::
array
<
ck_tile
::
index_t
,
1
>
{
1
});
if
(
init_method
==
"ui"
||
init_method
==
"0"
)
{
ck_tile
::
FillUniformDistributionIntegerValue
<
QDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
q_host
);
ck_tile
::
FillUniformDistributionIntegerValue
<
KDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
k_host
);
ck_tile
::
FillUniformDistributionIntegerValue
<
KDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
knew_host
);
ck_tile
::
FillUniformDistributionIntegerValue
<
VDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
v_host
);
ck_tile
::
FillUniformDistributionIntegerValue
<
VDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
vnew_host
);
ck_tile
::
FillUniformDistributionIntegerValue
<
BiasDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
bias_host
);
}
else
if
(
init_method
==
"ni"
)
{
ck_tile
::
FillNormalDistributionIntegerValue
<
QDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
q_host
);
ck_tile
::
FillNormalDistributionIntegerValue
<
KDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
k_host
);
ck_tile
::
FillNormalDistributionIntegerValue
<
KDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
knew_host
);
ck_tile
::
FillNormalDistributionIntegerValue
<
VDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
v_host
);
ck_tile
::
FillNormalDistributionIntegerValue
<
VDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
vnew_host
);
ck_tile
::
FillNormalDistributionIntegerValue
<
BiasDataType
>
{
-
3.
f
,
3.
f
,
seed
}(
bias_host
);
}
else
if
(
init_method
==
"uf"
||
init_method
==
"1"
)
{
ck_tile
::
FillUniformDistribution
<
QDataType
>
{
0.
f
,
1.
f
,
seed
}(
q_host
);
ck_tile
::
FillUniformDistribution
<
KDataType
>
{
0.
f
,
1.
f
,
seed
}(
k_host
);
ck_tile
::
FillUniformDistribution
<
KDataType
>
{
0.
f
,
1.
f
,
seed
}(
knew_host
);
ck_tile
::
FillUniformDistribution
<
VDataType
>
{
0.
f
,
1.
f
,
seed
}(
v_host
);
ck_tile
::
FillUniformDistribution
<
VDataType
>
{
0.
f
,
1.
f
,
seed
}(
vnew_host
);
ck_tile
::
FillUniformDistribution
<
BiasDataType
>
{
0.
f
,
1.
f
,
seed
}(
bias_host
);
}
else
if
(
init_method
==
"nf"
)
{
ck_tile
::
FillNormalDistribution
<
QDataType
>
{
0.
f
,
3.
f
,
seed
}(
q_host
);
ck_tile
::
FillNormalDistribution
<
KDataType
>
{
0.
f
,
3.
f
,
seed
}(
k_host
);
ck_tile
::
FillNormalDistribution
<
KDataType
>
{
0.
f
,
3.
f
,
seed
}(
knew_host
);
ck_tile
::
FillNormalDistribution
<
VDataType
>
{
0.
f
,
3.
f
,
seed
}(
v_host
);
ck_tile
::
FillNormalDistribution
<
VDataType
>
{
0.
f
,
3.
f
,
seed
}(
vnew_host
);
ck_tile
::
FillNormalDistribution
<
BiasDataType
>
{
0.
f
,
3.
f
,
seed
}(
bias_host
);
}
else
if
(
init_method
==
"tf"
||
init_method
==
"2"
)
{
ck_tile
::
FillTrigValue
<
QDataType
>
{}(
q_host
);
ck_tile
::
FillTrigValue
<
KDataType
>
{}(
k_host
);
ck_tile
::
FillTrigValue
<
KDataType
>
{}(
knew_host
);
ck_tile
::
FillTrigValue
<
VDataType
>
{}(
v_host
);
ck_tile
::
FillTrigValue
<
VDataType
>
{}(
vnew_host
);
ck_tile
::
FillTrigValue
<
BiasDataType
>
{}(
bias_host
);
}
else
if
(
init_method
==
"ufq"
||
init_method
==
"uf:q"
||
...
...
@@ -540,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile
::
FillUniformDistribution
<
QDataType
>
{
-
dtype_max
,
dtype_max
,
seed
}(
q_host
);
ck_tile
::
FillUniformDistribution
<
KDataType
>
{
-
dtype_max
,
dtype_max
,
seed
}(
k_host
);
ck_tile
::
FillUniformDistribution
<
KDataType
>
{
-
dtype_max
,
dtype_max
,
seed
}(
knew_host
);
ck_tile
::
FillUniformDistribution
<
VDataType
>
{
-
dtype_max
,
dtype_max
,
seed
}(
v_host
);
ck_tile
::
FillUniformDistribution
<
VDataType
>
{
-
dtype_max
,
dtype_max
,
seed
}(
vnew_host
);
// bias_fp8 = qscale_bias * bias_fp32
float
qscale_bias
=
(
dtype_max
/
range_q
)
*
(
dtype_max
/
range_k
);
...
...
@@ -550,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if
(
bias
.
type
==
bias_enum
::
alibi
)
{
auto
slopes
=
ck_tile
::
get_alibi_slopes
<
SaccDataType
>
(
nhead
);
assert
(
slopes
.
size
()
==
nhead
);
assert
(
slopes
.
size
()
==
static_cast
<
std
::
size_t
>
(
nhead
)
)
;
if
(
bias
.
rank_info
==
0
)
{
// alibi in 1*h
...
...
@@ -565,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
}
iota_shuffle
(
block_table_host
.
begin
(),
block_table_host
.
end
(),
0
);
iota_shuffle
(
cache_batch_idx_host
.
begin
(),
cache_batch_idx_host
.
end
(),
0
);
ck_tile
::
DeviceMem
q_buf
(
q_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
k_buf
(
k_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
knew_buf
(
knew_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
v_buf
(
v_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
vnew_buf
(
vnew_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
bias_buf
(
bias_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
lse_acc_buf
(
lse_acc_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
o_acc_buf
(
o_acc_host
.
get_element_space_size_in_bytes
());
...
...
@@ -576,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
DeviceMem
o_buf
(
o_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
seqstart_q
(
seqstart_q_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqstart_k
(
seqstart_k_host
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqlen_k_buf
(
seqlen_kpads
[
0
]
<
0
?
0
:
seqlen_ks
.
size
()
*
sizeof
(
int32_t
));
ck_tile
::
DeviceMem
seqlen_k_buf
(
use_kvcache
||
0
<=
seqlen_kpads
[
0
]
?
seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
ck_tile
::
DeviceMem
cache_seqlen_k_buf
(
need_append_kvcache
?
cache_seqlen_ks
.
size
()
*
sizeof
(
int32_t
)
:
0
);
ck_tile
::
DeviceMem
rotary_cos_buf
(
rotary_cos_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
rotary_sin_buf
(
rotary_sin_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
randval_buf
(
randval_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
alibi_slope_buf
(
alibi_slope_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
block_table_buf
(
block_table_host
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
cache_batch_idx_buf
(
cache_batch_idx_host
.
get_element_space_size_in_bytes
());
q_buf
.
ToDevice
(
q_host
.
data
());
k_buf
.
ToDevice
(
k_host
.
data
());
knew_buf
.
ToDevice
(
knew_host
.
data
());
v_buf
.
ToDevice
(
v_host
.
data
());
vnew_buf
.
ToDevice
(
vnew_host
.
data
());
bias_buf
.
ToDevice
(
bias_host
.
data
());
seqstart_q
.
ToDevice
(
seqstart_q_host
.
data
());
seqstart_k
.
ToDevice
(
seqlen_kpads
[
0
]
<
0
?
seqstart_k_host
.
data
()
:
seqstart_k_with_padding_host
.
data
());
seqlen_k_buf
.
ToDevice
(
seqlen_kpads
[
0
]
<
0
?
nullptr
:
seqlen_ks
.
data
());
seqlen_k_buf
.
ToDevice
(
use_kvcache
||
0
<=
seqlen_kpads
[
0
]
?
seqlen_ks
.
data
()
:
nullptr
);
cache_seqlen_k_buf
.
ToDevice
(
need_append_kvcache
?
cache_seqlen_ks
.
data
()
:
nullptr
);
rotary_cos_buf
.
ToDevice
(
rotary_cos_host
.
data
());
rotary_sin_buf
.
ToDevice
(
rotary_sin_host
.
data
());
alibi_slope_buf
.
ToDevice
(
alibi_slope_host
.
data
());
block_table_buf
.
ToDevice
(
block_table_host
.
data
());
cache_batch_idx_buf
.
ToDevice
(
cache_batch_idx_host
.
data
());
// clang-format off
auto
layout_str
=
[
&
](
bool
permute
){
if
(
permute
)
return
std
::
string
(
"bhsd"
);
if
(
permute
)
return
std
::
string
(
"bhsd"
);
else
return
std
::
string
(
"bshd"
);
};
auto
io_layout
=
[
&
](
bool
iperm_
,
bool
operm_
)
{
if
(
iperm_
==
operm_
)
return
layout_str
(
iperm_
);
if
(
iperm_
==
operm_
)
return
layout_str
(
iperm_
);
else
return
layout_str
(
iperm_
)
+
std
::
string
(
"-"
)
+
layout_str
(
operm_
);
};
// clang-format on
...
...
@@ -609,51 +780,77 @@ bool run(const ck_tile::ArgParser& arg_parser)
<<
", d:"
<<
hdim_q
<<
"/"
<<
hdim_v
<<
", scale_s:"
<<
scale_s
<<
", bias:"
<<
bias
<<
", p_drop:"
<<
p_drop
<<
", lse:"
<<
lse
<<
", squant:"
<<
squant
<<
", mask:"
<<
mask
<<
", v:"
<<
vlayout
;
#if CK_TILE_FMHA_FWD_APPENDKV_API
if
(
0
<
rotary_dim
)
{
std
::
cout
<<
", rotary_dim:"
<<
rotary_dim
<<
"("
<<
(
is_rotary_interleaved
?
"inter"
:
"half"
)
<<
")"
;
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
1
<
num_splits
)
{
std
::
cout
<<
", num_splits:"
<<
num_splits
;
}
if
(
0
<
page_block_size
)
{
std
::
cout
<<
", page_block_size:"
<<
page_block_size
;
}
if
(
use_cache_batch_idx
)
{
std
::
cout
<<
", cache_batch_idx:"
<<
use_cache_batch_idx
;
}
#endif
std
::
cout
<<
std
::
flush
;
auto
fmha_traits
=
fmha_fwd_traits
{
hdim_q
,
hdim_v
,
data_type
,
mode
==
mode_enum
::
group
,
is_v_rowmajor
,
mask
.
type
,
bias
.
type
,
lse
,
p_drop
>
0.0
f
,
squant
};
const
auto
init_traits
=
[
&
](
auto
&
traits
)
{
traits
.
hdim_q
=
hdim_q
;
traits
.
hdim_v
=
hdim_v
;
traits
.
data_type
=
data_type
;
traits
.
is_v_rowmajor
=
is_v_rowmajor
;
auto
p_compute_element_func
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck_tile
::
fp8_t
>
)
return
ck_tile
::
scales
{
scale_p
};
else
return
ck_tile
::
identity
{};
}();
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_appendkv_traits
,
std
::
decay_t
<
decltype
(
traits
)
>>
)
{
traits
.
rope_type
=
(
0
<
rotary_dim
?
(
is_rotary_interleaved
?
rope_enum
::
interleaved
:
rope_enum
::
half_rotated
)
:
rope_enum
::
none
);
}
else
// fmha_fwd_traits or fmha_splitkv_traits
{
traits
.
is_group_mode
=
(
mode
==
mode_enum
::
group
);
traits
.
mask_type
=
mask
.
type
;
traits
.
bias_type
=
bias
.
type
;
traits
.
has_lse
=
lse
;
traits
.
do_fp8_static_quant
=
squant
;
auto
oacc_element_func
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck_tile
::
fp8_t
>
)
return
ck_tile
::
composes
(
ck_tile
::
saturates
<
ck_tile
::
fp8_t
>
{},
ck_tile
::
scales
{
scale_o
});
else
return
ck_tile
::
identity
{};
}();
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_traits
,
std
::
decay_t
<
decltype
(
traits
)
>>
)
{
traits
.
has_dropout
=
(
p_drop
>
0.0
f
);
}
}
};
auto
fmha
_args
=
[
&
,
k_paddings_
=
seqlen_kpads
]()
{
const
auto
init
_args
=
[
&
,
k_paddings_
=
seqlen_kpads
](
auto
&
args
)
{
assert
(
nhead
%
nhead_k
==
0
);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const
ck_tile
::
index_t
stride_q
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_v
=
[
&
]()
{
const
ck_tile
::
index_t
stride_q
=
(
i_perm
?
hdim_q
:
nhead
*
hdim_q
);
const
ck_tile
::
index_t
stride_k
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_knew
=
(
i_perm
?
hdim_q
:
nhead_k
*
hdim_q
);
const
ck_tile
::
index_t
stride_v
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
i_perm
?
hdim_v
:
nhead_k
*
hdim_v
;
else
return
0
<
page_block_size
?
(
i_perm
?
page_block_size
:
nhead_k
*
page_block_size
)
:
(
i_perm
?
shape_seqlen_k
:
nhead_k
*
shape_seqlen_k
);
}();
const
ck_tile
::
index_t
stride_vnew
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
i_perm
?
hdim_v
:
nhead_k
*
hdim_v
;
else
return
i_perm
?
shape_
seqlen_k
:
nhead_k
*
shape_
seqlen_k
;
return
i_perm
?
seqlen_k
new
:
nhead_k
*
seqlen_k
new
;
}();
const
ck_tile
::
index_t
stride_bias
=
(
i_perm
?
shape_seqlen_k
:
1
*
shape_seqlen_k
);
const
ck_tile
::
index_t
stride_randval
=
(
max_seqlen_k
);
...
...
@@ -661,12 +858,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
stride_o
=
(
o_perm
?
hdim_v
:
nhead
*
hdim_v
);
// setup nhead_stride_* arguments
const
ck_tile
::
index_t
nhead_stride_q
=
(
i_perm
?
shape_seqlen_q
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_k
=
(
i_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_v
=
[
&
]()
{
const
ck_tile
::
index_t
nhead_stride_k
=
(
0
<
page_block_size
?
(
i_perm
?
page_block_size
*
hdim_q
:
hdim_q
)
:
(
i_perm
?
shape_seqlen_k
*
hdim_q
:
hdim_q
));
const
ck_tile
::
index_t
nhead_stride_knew
=
(
i_perm
?
seqlen_knew
*
hdim_q
:
hdim_q
);
const
ck_tile
::
index_t
nhead_stride_v
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
i_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
;
return
0
<
page_block_size
?
(
i_perm
?
page_block_size
*
hdim_v
:
hdim_v
)
:
(
i_perm
?
shape_seqlen_k
*
hdim_v
:
hdim_v
);
else
return
i_perm
?
hdim_v
*
shape_seqlen_k
:
shape_seqlen_k
;
return
0
<
page_block_size
?
(
i_perm
?
hdim_v
*
page_block_size
:
page_block_size
)
:
(
i_perm
?
hdim_v
*
shape_seqlen_k
:
shape_seqlen_k
);
}();
const
ck_tile
::
index_t
nhead_stride_vnew
=
[
&
]()
{
if
(
is_v_rowmajor
)
return
i_perm
?
seqlen_knew
*
hdim_v
:
hdim_v
;
else
return
i_perm
?
hdim_v
*
seqlen_knew
:
seqlen_knew
;
}();
const
ck_tile
::
index_t
nhead_stride_bias
=
(
i_perm
?
0
*
shape_seqlen_q
*
shape_seqlen_k
:
0
*
shape_seqlen_k
);
...
...
@@ -676,88 +884,194 @@ bool run(const ck_tile::ArgParser& arg_parser)
const
ck_tile
::
index_t
nhead_stride_o_acc
=
(
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
nhead_stride_o
=
(
o_perm
?
shape_seqlen_q
*
hdim_v
:
hdim_v
);
// setup batch_stride_* arguments
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_k
=
(
nhead_k
*
shape_seqlen_k
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_v
=
(
nhead_k
*
hdim_v
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_q
=
(
nhead
*
shape_seqlen_q
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_k
=
(
0
<
page_block_size
?
(
nhead_k
*
page_block_size
*
hdim_q
)
:
(
nhead_k
*
shape_seqlen_k
*
hdim_q
));
const
ck_tile
::
index_t
batch_stride_knew
=
(
nhead_k
*
seqlen_knew
*
hdim_q
);
const
ck_tile
::
index_t
batch_stride_v
=
(
0
<
page_block_size
?
(
nhead_k
*
hdim_v
*
page_block_size
)
:
(
nhead_k
*
hdim_v
*
shape_seqlen_k
));
const
ck_tile
::
index_t
batch_stride_vnew
=
(
nhead_k
*
hdim_v
*
seqlen_knew
);
const
ck_tile
::
index_t
batch_stride_bias
=
(
0
*
nhead
*
shape_seqlen_q
*
shape_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_randval
=
(
nhead
*
shape_seqlen_q
*
max_seqlen_k
);
const
ck_tile
::
index_t
batch_stride_lse
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_lse_acc
=
(
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
batch_stride_o_acc
=
(
nhead
*
max_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_o
=
(
nhead
*
shape_seqlen_q
*
hdim_v
);
const
ck_tile
::
index_t
batch_stride_block_table
=
(
max_num_page_blocks
/
batch
);
// setup split_stride_* arguments (only used in split-kv kernel)
const
ck_tile
::
index_t
split_stride_lse_acc
=
(
shape_batch
*
nhead
*
shape_seqlen_q
);
const
ck_tile
::
index_t
split_stride_o_acc
=
(
batch
*
nhead
*
max_seqlen_q
*
hdim_v
);
return
fmha_fwd_args
{
q_buf
.
GetDeviceBuffer
(),
k_buf
.
GetDeviceBuffer
(),
v_buf
.
GetDeviceBuffer
(),
bias
.
type
==
bias_enum
::
alibi
?
alibi_slope_buf
.
GetDeviceBuffer
()
:
bias_buf
.
GetDeviceBuffer
(),
randval_buf
.
GetDeviceBuffer
(),
lse_acc_buf
.
GetDeviceBuffer
(),
o_acc_buf
.
GetDeviceBuffer
(),
lse_buf
.
GetDeviceBuffer
(),
o_buf
.
GetDeviceBuffer
(),
seqstart_q
.
GetDeviceBuffer
(),
seqstart_k
.
GetDeviceBuffer
(),
k_paddings_
[
0
]
<
0
?
nullptr
:
seqlen_k_buf
.
GetDeviceBuffer
(),
shape_seqlen_q
,
shape_seqlen_k
,
batch
,
max_seqlen_q
,
hdim_q
,
hdim_v
,
nhead
,
nhead_k
,
num_splits
,
scale_s
,
scale_p
,
scale_o
,
stride_q
,
stride_k
,
stride_v
,
bias
.
type
==
bias_enum
::
alibi
?
(
bias
.
rank_info
==
0
?
0
:
nhead
)
:
stride_bias
,
stride_randval
,
stride_o_acc
,
stride_o
,
nhead_stride_q
,
nhead_stride_k
,
nhead_stride_v
,
nhead_stride_bias
,
nhead_stride_randval
,
nhead_stride_lse
,
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_bias
,
batch_stride_randval
,
batch_stride_lse
,
batch_stride_lse_acc
,
batch_stride_o_acc
,
batch_stride_o
,
split_stride_lse_acc
,
split_stride_o_acc
,
mask
.
left
,
mask
.
right
,
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
),
p_drop
,
s_randval
,
{
drop_seed
,
drop_offset
}};
args
.
q_ptr
=
q_buf
.
GetDeviceBuffer
();
args
.
k_ptr
=
k_buf
.
GetDeviceBuffer
();
args
.
v_ptr
=
v_buf
.
GetDeviceBuffer
();
args
.
batch
=
batch
;
args
.
seqlen_q
=
shape_seqlen_q
;
// unused in group mode
args
.
hdim_q
=
hdim_q
;
args
.
hdim_v
=
hdim_v
;
args
.
nhead_q
=
nhead
;
args
.
nhead_k
=
nhead_k
;
args
.
stride_q
=
stride_q
;
args
.
stride_k
=
stride_k
;
args
.
stride_v
=
stride_v
;
args
.
nhead_stride_q
=
nhead_stride_q
;
args
.
nhead_stride_k
=
nhead_stride_k
;
args
.
nhead_stride_v
=
nhead_stride_v
;
args
.
batch_stride_q
=
batch_stride_q
;
args
.
batch_stride_k
=
batch_stride_k
;
args
.
batch_stride_v
=
batch_stride_v
;
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_appendkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
args
.
knew_ptr
=
knew_buf
.
GetDeviceBuffer
();
args
.
vnew_ptr
=
vnew_buf
.
GetDeviceBuffer
();
args
.
seqlen_knew
=
seqlen_knew
;
args
.
seqlen_k_ptr
=
cache_seqlen_k_buf
.
GetDeviceBuffer
();
args
.
rotary_cos_ptr
=
(
0
<
rotary_dim
?
rotary_cos_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
rotary_sin_ptr
=
(
0
<
rotary_dim
?
rotary_sin_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
rotary_dim
=
rotary_dim
;
args
.
has_mask
=
(
mask
.
type
!=
mask_enum
::
no_mask
);
args
.
block_table_ptr
=
(
0
<
page_block_size
?
block_table_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
batch_stride_block_table
=
batch_stride_block_table
;
args
.
page_block_size
=
page_block_size
;
args
.
cache_batch_idx
=
(
use_cache_batch_idx
?
cache_batch_idx_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
stride_knew
=
stride_knew
;
args
.
stride_vnew
=
stride_vnew
;
args
.
nhead_stride_knew
=
nhead_stride_knew
;
args
.
nhead_stride_vnew
=
nhead_stride_vnew
;
args
.
batch_stride_knew
=
batch_stride_knew
;
args
.
batch_stride_vnew
=
batch_stride_vnew
;
}
else
// fmha_fwd_args or fmha_fwd_splitkv_args
{
args
.
bias_ptr
=
bias
.
type
==
bias_enum
::
alibi
?
alibi_slope_buf
.
GetDeviceBuffer
()
:
bias_buf
.
GetDeviceBuffer
();
args
.
lse_ptr
=
lse_buf
.
GetDeviceBuffer
();
args
.
o_ptr
=
o_buf
.
GetDeviceBuffer
();
args
.
seqstart_q_ptr
=
(
mode
==
mode_enum
::
group
?
seqstart_q
.
GetDeviceBuffer
()
:
nullptr
);
args
.
seqstart_k_ptr
=
(
mode
==
mode_enum
::
group
?
seqstart_k
.
GetDeviceBuffer
()
:
nullptr
);
args
.
seqlen_k_ptr
=
(
use_kvcache
||
0
<=
k_paddings_
[
0
]
?
seqlen_k_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
seqlen_k
=
shape_seqlen_k
;
// unused in group mode (or kvcache enabled)
args
.
max_seqlen_q
=
max_seqlen_q
;
args
.
scale_s
=
scale_s
;
args
.
scale_p
=
scale_p
;
args
.
scale_o
=
scale_o
;
args
.
stride_bias
=
(
bias
.
type
==
bias_enum
::
alibi
?
(
bias
.
rank_info
==
0
?
0
:
nhead
)
:
stride_bias
);
args
.
stride_o
=
stride_o
;
args
.
nhead_stride_bias
=
nhead_stride_bias
;
args
.
nhead_stride_lse
=
nhead_stride_lse
;
args
.
nhead_stride_o
=
nhead_stride_o
;
args
.
batch_stride_bias
=
batch_stride_bias
;
args
.
batch_stride_lse
=
batch_stride_lse
;
args
.
batch_stride_o
=
batch_stride_o
;
args
.
window_size_left
=
mask
.
left
;
args
.
window_size_right
=
mask
.
right
;
args
.
mask_type
=
static_cast
<
ck_tile
::
index_t
>
(
mask
.
type
);
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
args
.
rand_val_ptr
=
randval_buf
.
GetDeviceBuffer
();
args
.
stride_randval
=
stride_randval
;
args
.
nhead_stride_randval
=
nhead_stride_randval
;
args
.
batch_stride_randval
=
batch_stride_randval
;
args
.
p_drop
=
p_drop
;
args
.
s_randval
=
s_randval
;
args
.
drop_seed_offset
=
std
::
tie
(
drop_seed
,
drop_offset
);
}
else
if
constexpr
(
std
::
is_same_v
<
fmha_fwd_splitkv_args
,
std
::
decay_t
<
decltype
(
args
)
>>
)
{
args
.
lse_acc_ptr
=
lse_acc_buf
.
GetDeviceBuffer
();
args
.
o_acc_ptr
=
o_acc_buf
.
GetDeviceBuffer
();
args
.
block_table_ptr
=
(
0
<
page_block_size
?
block_table_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
batch_stride_block_table
=
batch_stride_block_table
;
args
.
page_block_size
=
page_block_size
;
args
.
cache_batch_idx
=
(
use_cache_batch_idx
?
cache_batch_idx_buf
.
GetDeviceBuffer
()
:
nullptr
);
args
.
num_splits
=
num_splits
;
args
.
stride_o_acc
=
stride_o_acc
;
args
.
nhead_stride_lse_acc
=
nhead_stride_lse_acc
;
args
.
nhead_stride_o_acc
=
nhead_stride_o_acc
;
args
.
batch_stride_lse_acc
=
batch_stride_lse_acc
;
args
.
batch_stride_o_acc
=
batch_stride_o_acc
;
args
.
split_stride_lse_acc
=
split_stride_lse_acc
;
args
.
split_stride_o_acc
=
split_stride_o_acc
;
}
}
};
const
float
appendkv_ave_time
=
[
&
]
{
#if CK_TILE_FMHA_FWD_APPENDKV_API
if
(
need_append_kvcache
)
{
fmha_fwd_appendkv_traits
fwd_appendkv_traits
;
init_traits
(
fwd_appendkv_traits
);
fmha_fwd_appendkv_args
fwd_appendkv_args
;
init_args
(
fwd_appendkv_args
);
return
fmha_fwd_appendkv
(
fwd_appendkv_traits
,
fwd_appendkv_args
,
stream_config
);
}
#endif
return
0.0
f
;
}();
float
ave_time
=
fmha_fwd_dispatch
(
fmha_traits
,
fmha_args
,
stream_config
);
const
float
fwd_ave_time
=
[
&
]
{
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
1
<
num_splits
||
use_kvcache
)
{
fmha_fwd_splitkv_traits
fmha_splitkv_traits
;
init_traits
(
fmha_splitkv_traits
);
fmha_fwd_splitkv_args
fmha_splitkv_args
;
init_args
(
fmha_splitkv_args
);
return
fmha_fwd_splitkv
(
fmha_splitkv_traits
,
fmha_splitkv_args
,
stream_config
);
}
#endif
fmha_fwd_traits
fmha_traits
;
init_traits
(
fmha_traits
);
fmha_fwd_args
fmha_args
;
init_args
(
fmha_args
);
if
(
ave_time
<
0
)
return
fmha_fwd
(
fmha_traits
,
fmha_args
,
stream_config
);
}();
if
(
appendkv_ave_time
<
0.0
f
||
fwd_ave_time
<
0.0
f
)
{
std
::
cout
<<
", not supported yet"
<<
std
::
flush
<<
std
::
endl
;
return
false
;
}
const
float
ave_time
=
(
appendkv_ave_time
+
fwd_ave_time
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_byte
/
1.E6
/
ave_time
;
...
...
@@ -775,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf
.
FromDevice
(
o_host
.
data
());
lse_buf
.
FromDevice
(
lse_host
.
data
());
randval_buf
.
FromDevice
(
randval_host
.
data
());
auto
p_compute_element_func
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck_tile
::
fp8_t
>
)
return
ck_tile
::
scales
{
scale_p
};
else
return
ck_tile
::
identity
{};
}();
auto
oacc_element_func
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
DataType
,
ck_tile
::
fp8_t
>
)
return
ck_tile
::
composes
(
ck_tile
::
saturates
<
ck_tile
::
fp8_t
>
{},
ck_tile
::
scales
{
scale_o
});
else
return
ck_tile
::
identity
{};
}();
float
p_undrop
=
1.0
-
p_drop
;
uint8_t
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
float
rp_undrop
=
1.0
/
p_undrop
;
bool
pass
=
true
;
for
(
ck_tile
::
index_t
wb
=
0
;
wb
<
batch
;
++
wb
)
{
const
ck_tile
::
index_t
real_seqlen_q
=
seqstart_q_host
[
wb
+
1
]
-
seqstart_q_host
[
wb
];
const
ck_tile
::
index_t
real_seqlen_k
=
seqstart_k_host
[
wb
+
1
]
-
seqstart_k_host
[
wb
];
// adjust matrix index according to the mode
const
ck_tile
::
index_t
b
=
(
mode
==
mode_enum
::
batch
?
wb
:
0
);
const
ck_tile
::
index_t
b_idx
=
(
mode
==
mode_enum
::
batch
?
wb
:
0
);
const
ck_tile
::
index_t
cache_b_idx
=
(
use_cache_batch_idx
?
cache_batch_idx_host
(
b_idx
)
:
b_idx
);
const
ck_tile
::
index_t
query_offset
=
(
mode
==
mode_enum
::
batch
?
0
:
seqstart_q_host
[
wb
]);
const
ck_tile
::
index_t
key_offset
=
(
mode
==
mode_enum
::
batch
?
0
:
(
seqlen_kpads
[
0
]
<
0
?
seqstart_k_host
[
wb
]
:
seqstart_k_with_padding_host
[
wb
]));
const
auto
v_host_ref_lengths
=
std
::
array
<
ck_tile
::
index_t
,
3
>
{
nhead
,
hdim_v
,
real_seqlen_k
};
const
auto
v_host_ref_strides
=
is_v_rowmajor
?
std
::
array
<
ck_tile
::
index_t
,
3
>
{
hdim_v
*
real_seqlen_k
,
1
,
hdim_v
}
:
std
::
array
<
ck_tile
::
index_t
,
3
>
{
hdim_v
*
real_seqlen_k
,
real_seqlen_k
,
1
};
ck_tile
::
HostTensor
<
QDataType
>
q_host_ref
({
nhead
,
real_seqlen_q
,
hdim_q
});
ck_tile
::
HostTensor
<
KDataType
>
k_host_ref
({
nhead
,
real_seqlen_k
,
hdim_q
});
ck_tile
::
HostTensor
<
VDataType
>
v_host_ref
(
v_host_ref_lengths
,
v_host_ref_strides
);
ck_tile
::
HostTensor
<
VDataType
>
v_host_ref
(
{
nhead
,
hdim_v
,
real_seqlen_k
}
);
ck_tile
::
HostTensor
<
ODataType
>
o_host_ref
({
nhead
,
real_seqlen_q
,
hdim_v
});
ck_tile
::
HostTensor
<
SMPLComputeDataType
>
s_host_ref
({
nhead
,
real_seqlen_q
,
real_seqlen_k
});
...
...
@@ -815,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off
// permute
if
(
i_perm
)
q_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
q_host
(
b
,
i
[
0
],
i
[
1
]
+
query_offset
,
i
[
2
]);
});
else
q_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
q_host
(
b
,
i
[
1
]
+
query_offset
,
i
[
0
],
i
[
2
]);
});
if
(
i_perm
)
q_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
q_host
(
b_idx
,
i
[
0
],
i
[
1
]
+
query_offset
,
i
[
2
]);
});
else
q_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
q_host
(
b_idx
,
i
[
1
]
+
query_offset
,
i
[
0
],
i
[
2
]);
});
#if CK_TILE_FMHA_FWD_APPENDKV_API
// optionally apply RoPE to the q_host_ref
if
(
0
<
rotary_dim
)
{
decltype
(
q_host_ref
)
q_host_ref_ro
(
q_host_ref
.
get_lengths
());
auto
[
rotary_cos_slice
,
rotary_sin_slice
]
=
slice_rotary_cos_sin
(
rotary_cos_host
,
rotary_sin_host
,
cache_seqlen_ks
[
wb
],
real_seqlen_q
);
ck_tile
::
reference_batched_rotary_position_embedding
(
q_host_ref
,
rotary_cos_slice
,
rotary_sin_slice
,
is_rotary_interleaved
,
q_host_ref_ro
,
/*use_1_row_sin_cos=*/
mask
.
type
==
mask_enum
::
no_mask
);
if
(
i_perm
)
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
b
,
i
[
0
]
/
nr
,
i
[
1
]
+
key_offset
,
i
[
2
]);
});
else
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
b
,
i
[
1
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
2
]);
});
q_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
q_host_ref_ro
(
i
);
});
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
page_block_size
)
{
if
(
i_perm
)
{
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
block_table_host
(
wb
,
i
[
1
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
1
]
%
page_block_size
,
i
[
2
]);
});
}
else
{
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
block_table_host
(
wb
,
i
[
1
]
/
page_block_size
),
i
[
1
]
%
page_block_size
,
i
[
0
]
/
nr
,
i
[
2
]);
});
}
}
else
#endif
{
if
(
i_perm
)
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
]
+
key_offset
,
i
[
2
]);
});
else
k_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
k_host
(
cache_b_idx
,
i
[
1
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
2
]);
});
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Knew to the end of K
if
(
0
<
seqlen_knew
)
{
ck_tile
::
HostTensor
<
KDataType
>
knew_host_ref
({
nhead
,
seqlen_knew
,
hdim_q
});
if
(
i_perm
)
knew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
knew_host
(
wb
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]);
});
else
knew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
knew_host
(
wb
,
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]);
});
// optionally apply RoPE to the knew_host_ref
auto
*
real_knew_host_ref
=
&
knew_host_ref
;
std
::
optional
<
decltype
(
knew_host_ref
)
>
knew_host_ref_ro
;
if
(
0
<
rotary_dim
)
{
knew_host_ref_ro
.
emplace
(
knew_host_ref
.
get_lengths
());
auto
[
rotary_cos_slice
,
rotary_sin_slice
]
=
slice_rotary_cos_sin
(
rotary_cos_host
,
rotary_sin_host
,
cache_seqlen_ks
[
wb
],
seqlen_knew
);
ck_tile
::
reference_batched_rotary_position_embedding
(
knew_host_ref
,
rotary_cos_slice
,
rotary_sin_slice
,
is_rotary_interleaved
,
knew_host_ref_ro
.
value
());
real_knew_host_ref
=
&
knew_host_ref_ro
.
value
();
}
if
(
is_v_rowmajor
)
{
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if
(
i_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
b
,
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
,
i
[
1
]);
});
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
b
,
i
[
2
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
1
]);
});
(
*
real_knew_host_ref
).
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
k_host_ref
(
i
[
0
],
i
[
1
]
+
cache_seqlen_ks
[
wb
],
i
[
2
])
=
self
(
i
);
});
}
else
{
if
(
i_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
b
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
+
key_offset
);
});
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
b
,
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
);
});
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if
(
0
<
page_block_size
)
{
if
(
is_v_rowmajor
)
{
if
(
i_perm
)
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
2
]
%
page_block_size
,
i
[
1
]);
});
}
else
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
2
]
%
page_block_size
,
i
[
0
]
/
nr
,
i
[
1
]);
});
}
}
else
{
if
(
i_perm
)
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
%
page_block_size
);
});
}
else
{
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
block_table_host
(
wb
,
i
[
2
]
/
page_block_size
),
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]
%
page_block_size
);
});
}
}
}
else
#endif
{
if
(
is_v_rowmajor
)
{
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if
(
i_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
,
i
[
1
]);
});
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
2
]
+
key_offset
,
i
[
0
]
/
nr
,
i
[
1
]);
});
}
else
{
if
(
i_perm
)
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]
+
key_offset
);
});
else
v_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
v_host
(
cache_b_idx
,
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]
+
key_offset
);
});
}
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Vnew to the end of V
if
(
0
<
seqlen_knew
)
{
ck_tile
::
HostTensor
<
VDataType
>
vnew_host_ref
({
nhead
,
hdim_v
,
seqlen_knew
});
if
(
is_v_rowmajor
)
{
if
(
i_perm
)
vnew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
vnew_host
(
wb
,
i
[
0
]
/
nr
,
i
[
2
],
i
[
1
]);
});
else
vnew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
vnew_host
(
wb
,
i
[
2
],
i
[
0
]
/
nr
,
i
[
1
]);
});
}
else
{
if
(
i_perm
)
vnew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
vnew_host
(
wb
,
i
[
0
]
/
nr
,
i
[
1
],
i
[
2
]);
});
else
vnew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
vnew_host
(
wb
,
i
[
1
],
i
[
0
]
/
nr
,
i
[
2
]);
});
}
vnew_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
v_host_ref
(
i
[
0
],
i
[
1
],
i
[
2
]
+
cache_seqlen_ks
[
wb
])
=
self
(
i
);
});
}
#endif
// clang-format on
// reference
...
...
@@ -959,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
RandValOutputDataType
>
randval_host_ref
(
{
nhead
,
real_seqlen_q
,
real_seqlen_k
});
randval_host_ref
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
randval_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
self
(
idx
)
=
randval_host
(
b
_idx
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
});
ck_tile
::
reference_batched_dropout
(
p_host_ref
,
randval_host_ref
,
p_undrop_in_uint8_t
,
rp_undrop
);
...
...
@@ -976,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile
::
HostTensor
<
ODataType
>
o_host_result
({
nhead
,
real_seqlen_q
,
hdim_v
});
// clang-format off
// permute
if
(
o_perm
)
o_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
o_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
});
else
o_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
o_host
(
b
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
]);
});
if
(
o_perm
)
o_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
o_host
(
b
_idx
,
idx
[
0
],
idx
[
1
]
+
query_offset
,
idx
[
2
]);
});
else
o_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
o_host
(
b
_idx
,
idx
[
1
]
+
query_offset
,
idx
[
0
],
idx
[
2
]);
});
// clang-format on
auto
[
rtol
,
atol
]
=
get_elimit
<
DataType
>
(
init_method
);
...
...
@@ -999,7 +1439,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile
::
HostTensor
<
SMPLComputeDataType
>
lse_host_result
({
nhead
,
real_seqlen_q
});
lse_host_result
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
=
lse_host
(
b
,
idx
[
0
],
idx
[
1
]
+
query_offset
);
self
(
idx
)
=
lse_host
(
b
_idx
,
idx
[
0
],
idx
[
1
]
+
query_offset
);
});
cur_pass
=
ck_tile
::
check_err
(
lse_host_result
,
...
...
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
4885c38a
...
...
@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
template
<
typename
DataType
>
...
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
rand_val_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
// only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
float
scale_s
;
float
scale_p
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
};
struct
fmha_fwd_splitkv_args
{
const
void
*
q_ptr
;
const
void
*
k_ptr
;
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
lse_acc_ptr
;
void
*
o_acc_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
void
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
const
void
*
cache_batch_idx
;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
...
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
ck_tile
::
index_t
num_splits
;
float
scale_s
;
float
scale_p
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o_acc
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
...
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
};
// Host-side argument pack for the FMHA forward "append-kv" kernel, which
// appends the freshly-computed Knew/Vnew tensors onto the existing K/V cache
// (and, per fmha_fwd_appendkv_create_kargs_and_grids, optionally applies
// rotary position embedding when rotary_dim > 0).
struct fmha_fwd_appendkv_args
{
    // NOTE(review): q_ptr/k_ptr/v_ptr are non-const here (unlike the other arg
    // structs) — presumably the kernel updates them in place (RoPE on Q,
    // cache append on K/V); confirm against the kernel implementation.
    void* q_ptr;
    void* k_ptr;
    const void* knew_ptr; // new K slice to append (read-only)
    void* v_ptr;
    const void* vnew_ptr;      // new V slice to append (read-only)
    const void* seqlen_k_ptr;  // per-batch current cache lengths (device memory)

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_knew; // number of new K/V positions being appended
    ck_tile::index_t batch;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k; // kv head count; must divide nhead_q (GQA/MQA)

    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
    ck_tile::index_t rotary_dim;
    bool has_mask;

    // paged-KV support
    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
    const void* cache_batch_idx; // optional batch index remapping into the cache

    // element strides (innermost/row strides)
    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_knew;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_vnew;
    // per-head strides
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_knew;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_vnew;
    // per-batch strides
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_knew;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_vnew;
};
template
<
typename
FmhaKernel
>
...
...
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
template
<
typename
Kernel
>
auto
fmha_fwd_splitkv_create_kargs_and_grids
(
fmha_fwd_args
args
)
auto
fmha_fwd_splitkv_create_kargs_and_grids
(
fmha_fwd_
splitkv_
args
args
)
{
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
auto
kargs
=
[
&
]
{
...
...
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
k_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
lse_acc_ptr
,
args
.
o_acc_ptr
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqstart_q_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqlen_k_ptr
,
...
...
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_o_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
mask_type
);
}
else
{
// create batch mode kernel arguments
...
...
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
k_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
lse_acc_ptr
,
args
.
o_acc_ptr
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
seqlen_k_ptr
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
num_splits
,
args
.
block_table_ptr
,
args
.
batch_stride_block_table
,
args
.
page_block_size
,
args
.
cache_batch_idx
,
args
.
scale_s
,
args
.
scale_p
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_o_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_bias
,
args
.
batch_stride_randval
,
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
args
.
mask_type
);
}
}();
...
...
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}
template
<
typename
Kernel
>
auto
fmha_fwd_splitkv_combine_create_kargs_and_grids
(
fmha_fwd_args
args
)
auto
fmha_fwd_splitkv_combine_create_kargs_and_grids
(
fmha_fwd_
splitkv_
args
args
)
{
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
auto
kargs
=
[
&
]
{
...
...
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
return
ck_tile
::
make_tuple
(
kargs
,
grids
);
}
// Build the device-side kernel argument struct and the launch grid for the
// append-kv kernel from the host-side argument pack.
// The positional argument order of Kernel::MakeKargs must match the kernel's
// signature exactly — do not reorder.
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
    // GQA/MQA requires the query head count to be a multiple of the kv head count.
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = Kernel::MakeKargs(args.q_ptr,
                                   args.k_ptr,
                                   args.knew_ptr,
                                   args.v_ptr,
                                   args.vnew_ptr,
                                   args.seqlen_q,
                                   args.seqlen_k_ptr,
                                   args.seqlen_knew,
                                   args.hdim_q,
                                   args.hdim_v,
                                   args.nhead_q,
                                   // number of query heads sharing each kv head
                                   args.nhead_q / args.nhead_k,
                                   args.rotary_cos_ptr,
                                   args.rotary_sin_ptr,
                                   args.rotary_dim,
                                   args.has_mask,
                                   args.block_table_ptr,
                                   args.batch_stride_block_table,
                                   args.page_block_size,
                                   args.cache_batch_idx,
                                   args.stride_q,
                                   args.stride_k,
                                   args.stride_knew,
                                   args.stride_v,
                                   args.stride_vnew,
                                   args.nhead_stride_q,
                                   args.nhead_stride_k,
                                   args.nhead_stride_knew,
                                   args.nhead_stride_v,
                                   args.nhead_stride_vnew,
                                   args.batch_stride_q,
                                   args.batch_stride_k,
                                   args.batch_stride_knew,
                                   args.batch_stride_v,
                                   args.batch_stride_vnew);

    // Grid is partitioned over (batch, q-heads, seqlen_q, seqlen_knew);
    // the exact tiling/rounding is decided inside Kernel::GridSize.
    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
    return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
...
...
@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
template
<
typename
Traits_
>
float
fmha_fwd_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
// Compile-time trait bundle describing one split-kv forward kernel instance.
// Used to pattern-match a generated kernel implementation (tile sizes, layout,
// padding and feature flags), not to instantiate the kernel directly.
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          typename FmhaMask_,
          ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
          bool kIsPagedKV_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_; // head dimension this kernel handles
    using DataType = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode = kIsGroupMode_; // group (varlen) vs batch mode
    // block-tile sizes of the attention GEMMs
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN0 = kN0_;
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;
    static constexpr ck_tile::index_t kK1 = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
    using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
    static constexpr auto BiasEnum = BiasEnum_;
    static constexpr bool kStoreLse = kStoreLse_; // emit log-sum-exp output
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    // padding flags for non-tile-multiple seqlen_q / seqlen_k / hdim_q / hdim_v
    static constexpr bool kPadS = kPadS_;
    static constexpr bool kPadSK = kPadSK_;
    static constexpr bool kPadD = kPadD_;
    static constexpr bool kPadDv = kPadDv_;
    static constexpr bool kIsPagedKV = kIsPagedKV_; // paged (block-table) KV cache support
};
template
<
typename
Traits_
>
void
fmha_fwd_splitkv_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
void
fmha_fwd_splitkv_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_
splitkv_
args
);
template
<
typename
Traits_
>
std
::
string
fmha_fwd_splitkv_get_name_
();
...
...
@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
};
template
<
typename
Traits_
>
void
fmha_fwd_splitkv_combine_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
void
fmha_fwd_splitkv_combine_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_
splitkv_
args
);
template
<
typename
Traits_
>
std
::
string
fmha_fwd_splitkv_combine_get_name_
();
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
// Compile-time traits bundle describing one append-KV kernel instantiation.
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kTileSizeS_,
          ck_tile::index_t kTileSizeSk_,
          ck_tile::index_t kTileSizeD_,
          ck_tile::index_t kTileSizeDv_,
          bool kIsVLayoutRowMajor_,
          bool kPadS_,
          bool kPadSk_,
          bool kPadD_,
          bool kPadDv_,
          ck_tile::RotaryEmbeddingEnum RotaryEnum_,
          bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_; // head dimension
    using DataType                         = ck_tile::remove_cvref_t<DataType_>;
    // tile sizes along seqlen_q (S), seqlen_k (Sk), hdim_q (D), hdim_v (Dv)
    static constexpr ck_tile::index_t kTileSizeS  = kTileSizeS_;
    static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
    static constexpr ck_tile::index_t kTileSizeD  = kTileSizeD_;
    static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
    static constexpr bool kIsVLayoutRowMajor      = kIsVLayoutRowMajor_; // V tensor layout
    // padding flags matching the tile dimensions above
    static constexpr bool kPadS  = kPadS_;
    static constexpr bool kPadSk = kPadSk_;
    static constexpr bool kPadD  = kPadD_;
    static constexpr bool kPadDv = kPadDv_;
    static constexpr auto RotaryEnum = RotaryEnum_; // rotary-embedding variant
    static constexpr bool kIsPagedKV = kIsPagedKV_; // paged KV-cache support
};
template
<
typename
Traits_
>
float
fmha_fwd_appendkv_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_appendkv_args
);
// This is the public API, will be generated by script
struct
fmha_fwd_traits
{
...
...
@@ -508,4 +743,32 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
};
float
fmha_fwd
(
fmha_fwd_traits
,
fmha_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_fwd_splitkv
(
fmha_fwd_traits
,
fmha_fwd_args
,
const
ck_tile
::
stream_config
&
);
// Runtime kernel-selection parameters for the split-KV forward path.
struct fmha_fwd_splitkv_traits
{
    int hdim_q; // head dim of Q/K
    int hdim_v; // head dim of V
    std::string data_type;
    bool is_group_mode;
    bool is_v_rowmajor; // layout of the V tensor
    mask_enum mask_type;
    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
    bool has_lse;        // request log-sum-exp output
    bool do_fp8_static_quant;
    // TODO: padding check is inside this api
};
float
fmha_fwd_splitkv
(
fmha_fwd_splitkv_traits
,
fmha_fwd_splitkv_args
,
const
ck_tile
::
stream_config
&
);
// Runtime kernel-selection parameters for the append-KV path.
struct fmha_fwd_appendkv_traits
{
    int hdim_q; // head dim of Q/K
    int hdim_v; // head dim of V
    std::string data_type;
    bool is_v_rowmajor;  // layout of the V tensor
    rope_enum rope_type; // rotary-embedding variant (see rope_enum)
};
float
fmha_fwd_appendkv
(
fmha_fwd_appendkv_traits
,
fmha_fwd_appendkv_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/01_fmha/generate.py
View file @
4885c38a
...
...
@@ -5,25 +5,30 @@
import
argparse
from
enum
import
IntEnum
from
pathlib
import
Path
import
pkgutil
import
sys
from
typing
import
List
,
Optional
import
codegen.ops
from
codegen.cmake_config
import
*
from
codegen.ops
import
(
fmha_fwd
,
fmha_fwd_splitkv
,
fmha_bwd
)
class
HandlerId
(
IntEnum
):
LIST_BLOBS
=
0
WRITE_BLOBS
=
1
handlers
=
{
'fwd'
:
(
fmha_fwd
.
list_blobs
,
fmha_fwd
.
write_blobs
),
'fwd_splitkv'
:
(
fmha_fwd_splitkv
.
list_blobs
,
fmha_fwd_splitkv
.
write_blobs
),
'bwd'
:
(
fmha_bwd
.
list_blobs
,
fmha_bwd
.
write_blobs
),
}
# inspect all modules under 'codegen.ops' and register API handlers
ops
=
[]
for
importer
,
module_name
,
_
in
pkgutil
.
iter_modules
(
codegen
.
ops
.
__path__
):
full_module_name
=
'%s.%s'
%
(
codegen
.
ops
.
__name__
,
module_name
)
if
full_module_name
not
in
sys
.
modules
:
ops
.
append
(
importer
.
find_spec
(
module_name
).
loader
.
load_module
(
module_name
))
unwanted_prefix
=
'fmha_'
handlers
=
dict
(
[(
op
.
__name__
[
len
(
unwanted_prefix
):]
if
op
.
__name__
.
startswith
(
unwanted_prefix
)
else
op
.
__name__
,
(
op
.
list_blobs
,
op
.
write_blobs
))
for
op
in
ops
]
)
assert
0
<
len
(
handlers
)
def
write_blobs
(
output_dir
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
if
output_dir
is
None
:
...
...
@@ -103,4 +108,4 @@ if __name__ == "__main__":
if
args
.
list_blobs
is
not
None
:
list_blobs
(
args
.
list_blobs
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
else
:
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
\ No newline at end of file
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
example/ck_tile/01_fmha/rotary.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum
class
rope_enum
{
none
=
0
,
interleaved
=
1
,
half_rotated
=
2
,
};
// Generate random RoPE cos/sin lookup tables of shape
// (seqlen * 2, rotary_dim / 2): angles are sampled uniformly in [0, 2*pi)
// and their cos()/sin() values are stored element-wise.
// Returns 1x1 placeholder tensors when rotary_dim <= 0 (RoPE disabled).
// Deterministic when 'seed' is provided.
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
                        ck_tile::index_t rotary_dim,
                        std::optional<unsigned> seed = std::nullopt)
{
    // RoPE disabled -> hand back placeholder tensors only
    if(rotary_dim <= 0)
    {
        ck_tile::HostTensor<DataType> placeholder({1, 1});
        return std::make_tuple(placeholder, placeholder);
    }

    std::mt19937 rng(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

    const ck_tile::index_t rows = seqlen * 2;
    const ck_tile::index_t cols = rotary_dim / 2;

    using std::begin, std::end;

    // sample one random angle per table entry
    ck_tile::HostTensor<float> angle({rows, cols});
    std::generate(begin(angle), end(angle), [&] { return uniform(rng) * 2 * M_PI; });

    ck_tile::HostTensor<DataType> cos({rows, cols});
    std::transform(begin(angle), end(angle), begin(cos), [](float radian) {
        return ck_tile::type_convert<DataType>(std::cos(radian));
    });

    ck_tile::HostTensor<DataType> sin({rows, cols});
    std::transform(begin(angle), end(angle), begin(sin), [](float radian) {
        return ck_tile::type_convert<DataType>(std::sin(radian));
    });

    return std::make_tuple(cos, sin);
}
// Copy rows [seqlen_offset, seqlen_offset + seqlen) out of the given RoPE
// cos/sin tables and return the two row slices as freshly allocated tensors.
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
                     const ck_tile::HostTensor<DataType>& sin,
                     ck_tile::index_t seqlen_offset,
                     ck_tile::index_t seqlen)
{
    // both tables must be 2-D, same shape, and large enough for the window
    assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
    assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
    assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));

    const ck_tile::index_t rows = seqlen;
    const ck_tile::index_t cols = cos.get_length(1);

    // copy the requested row window of 'src' into a new tensor
    auto copy_rows = [&](const ck_tile::HostTensor<DataType>& src) {
        ck_tile::HostTensor<DataType> dst({rows, cols});
        dst.ForEach([&](auto& self, auto idx) { self(idx) = src(idx[0] + seqlen_offset, idx[1]); });
        return dst;
    };

    return std::make_tuple(copy_rows(cos), copy_rows(sin));
}
example/ck_tile/01_fmha/script/benchmark_bwd.sh
View file @
4885c38a
#!/bin/sh
# TODO: run this script from CK root
BUILD
=
build
EXE
=
$BUILD
/bin/tile_example_fmha_bwd
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
VALID
=
0
for
prec
in
"fp16"
"bf16"
;
do
...
...
example/ck_tile/01_fmha/script/benchmark_fwd.sh
View file @
4885c38a
#!/bin/sh
# TODO: run this script from CK root
BUILD
=
build
EXE
=
$BUILD
/bin/tile_example_fmha_fwd
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
VALID
=
0
for
prec
in
"fp16"
"bf16"
;
do
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
View file @
4885c38a
#!/bin/sh
# TODO: run this script from CK root
BUILD
=
build
EXE
=
$BUILD
/bin/tile_example_fmha_bwd
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
...
...
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
View file @
4885c38a
#!/bin/sh
# TODO: run this script from CK root
BUILD
=
build
EXE
=
$BUILD
/bin/tile_example_fmha_fwd
#!/bin/bash
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
...
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
mode
in
1 0
;
do
for
perm
in
0 1
;
do
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
done
done
done
TEST_SPLITKV
=
0
TEST_APPENDKV
=
0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while
getopts
":sa"
opt
;
do
case
"
${
opt
}
"
in
s
)
TEST_SPLITKV
=
1
;;
a
)
TEST_APPENDKV
=
1
;;
*
)
;;
esac
done
run_fp16_bf16_tests
()
{
local
NUM_SPLITS
=(
1
)
local
PAGE_BLOCK_SIZE
=(
0
)
local
CACHE_BATCH_IDX
=(
0
)
for
perm
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
b
in
1 2
;
do
for
hdim
in
64 128 256
;
do
$EXE
-prec
=
fp8
-init
=
3
-b
=
$b
-h
=
1
-d
=
128
-s
=
128
-bias
=
$bias
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-squant
=
1
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
set
+x
if
[
$TEST_SPLITKV
-eq
1
]
;
then
NUM_SPLITS+
=(
2 3
)
PAGE_BLOCK_SIZE+
=(
128
)
CACHE_BATCH_IDX+
=(
1
)
fi
for
prec
in
"fp16"
"bf16"
;
do
for
mode
in
1 0
;
do
for
perm
in
0 1
;
do
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
for
num_splits
in
"
${
NUM_SPLITS
[@]
}
"
;
do
for
page_block_size
in
"
${
PAGE_BLOCK_SIZE
[@]
}
"
;
do
for
cache_batch_idx
in
"
${
CACHE_BATCH_IDX
[@]
}
"
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
;
done
;
done
done
;
}
run_fp8_tests
()
{
for
perm
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
b
in
1 2
;
do
for
hdim
in
64 128 256
;
do
$EXE
-prec
=
fp8
-init
=
3
-b
=
$b
-h
=
1
-d
=
128
-s
=
128
-bias
=
$bias
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-squant
=
1
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
}
run_fp16_appendkv_tests
()
{
for
s
in
$(
seq
63 1 65
)
;
do
for
s_k
in
65 129
;
do
for
s_knew
in
0 64
$s_k
;
do
for
hdim
in
32 64 128 256
;
do
for
ri
in
0 1
;
do
for
rdim
in
0 16 32
$hdim
;
do
for
page_block_size
in
0 128
;
do
for
cache_batch_idx
in
0 1
;
do
$EXE
-prec
=
fp16
-b
=
3
-h
=
3
-d
=
$hdim
-s
=
$s
-s_k
=
$s_k
-s_knew
=
$s_knew
-rotary_dim
=
$rdim
-rotary_interleaved
=
$ri
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-iperm
=
1
-operm
=
1
-kname
=
1
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
}
set
-x
run_fp16_bf16_tests
run_fp8_tests
if
[
$TEST_APPENDKV
-eq
1
]
;
then
run_fp16_appendkv_tests
fi
set
+x
\ No newline at end of file
example/ck_tile/01_fmha/utils.hpp
View file @
4885c38a
...
...
@@ -3,15 +3,17 @@
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp"
...
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std
::
vector
<
int32_t
>
generate_seqlens
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
// if not negative, clamp min
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
assert
(
0
<
count
);
std
::
vector
<
int32_t
>
seqlens
(
count
,
seqlen_max
>
0
?
(
seqlen_avg
<
seqlen_max
?
seqlen_avg
:
seqlen_max
)
:
seqlen_avg
);
seqlen_min
=
(
0
<
seqlen_min
?
seqlen_min
:
1
);
seqlen_max
=
(
0
<
seqlen_max
?
seqlen_max
:
std
::
numeric_limits
<
int32_t
>::
max
());
assert
(
seqlen_min
<=
seqlen_max
);
std
::
vector
<
int32_t
>
seqlens
(
count
,
std
::
clamp
(
seqlen_avg
,
seqlen_min
,
seqlen_max
));
if
(
mode
==
mode_enum
::
group
&&
1
<
count
)
{
...
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
{
const
size_type
to_decrease
=
next_idx
();
// make sure each elements of seqlens is
always greater than 0
if
(
seqlens
[
to_decrease
]
==
1
)
// make sure each element of seqlens stays in range [seqlen_min, seqlen_max]
if
(
seqlens
[
to_decrease
]
==
seqlen_min
)
{
continue
;
}
const
size_type
to_increase
=
(
to_decrease
+
next_step
())
%
count
;
if
(
seqlen_max
>
0
&&
seqlens
[
to_increase
]
>=
seqlen_max
)
if
(
seqlens
[
to_increase
]
>=
seqlen_max
)
{
continue
;
}
...
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std
::
vector
<
int32_t
>
generate_seqstarts
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
int32_t
seqlen_max
=
-
1
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_max
,
seed
));
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_min
,
seqlen_max
,
seed
));
}
// Return one random integer drawn uniformly from the closed range [low, high].
// A fixed 'seed' makes the result deterministic; otherwise std::random_device
// supplies a non-deterministic seed.
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
    -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
    const unsigned seed_value = (seed.has_value() ? *seed : std::random_device{}());
    std::mt19937 rng(seed_value);
    std::uniform_int_distribution<Int> distribution(low, high);
    return distribution(rng);
}
// Fill the range [first, last) with random integers drawn uniformly from the
// closed range [low, high]. A fixed 'seed' makes the sequence deterministic.
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
              ForwardIterator last,
              Int low,
              Int high,
              std::optional<unsigned> seed = std::nullopt)
    -> std::enable_if_t<std::is_integral_v<Int>>
{
    const unsigned seed_value = (seed.has_value() ? *seed : std::random_device{}());
    std::mt19937 rng(seed_value);
    std::uniform_int_distribution<Int> distribution(low, high);
    for(; first != last; ++first)
    {
        *first = distribution(rng);
    }
}
/*
...
...
@@ -112,16 +144,45 @@ decode_seqlen(mode_enum mode,
std
::
string
q_val
,
std
::
string
k_val
,
std
::
string
k_pad_val
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
ck_tile
::
index_t
seqlen_k_min
=
0
,
bool
use_kvcache
=
false
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if
(
mode
==
mode_enum
::
batch
)
{
ck_tile
::
index_t
q
=
_S2I_
(
q_val
);
ck_tile
::
index_t
k
=
_S2I_
(
k_val
);
auto
s_q
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
q
);
auto
s_k
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
k
<
0
?
q
:
k
);
auto
s_q
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
q
);
auto
s_k
=
[
&
]
{
const
ck_tile
::
index_t
seqlen_k_max
=
(
k
<
0
?
q
:
k
);
std
::
vector
<
ck_tile
::
index_t
>
seqlen_ks
(
batch
,
seqlen_k_max
);
if
(
1
<
batch
&&
use_kvcache
)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints
(
std
::
next
(
seqlen_ks
.
begin
()),
seqlen_ks
.
end
(),
seqlen_k_min
,
seqlen_k_max
,
seed
);
return
seqlen_ks
;
}
return
seqlen_ks
;
}();
auto
s_kpad
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
-
1
);
// TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
return
std
::
make_tuple
(
s_q
,
s_k
,
s_kpad
);
}
else
...
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
s_q
.
push_back
(
q
);
s_k
.
push_back
(
k
<
0
?
q
:
k
);
s_kpad
.
push_back
(
kp
);
// s_k should be greater than or equal to seqlen_k_min
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
idx
++
;
if
(
found_q
==
std
::
string
::
npos
||
idx
>=
batch
)
{
...
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
}
if
(
idx
<
batch
)
{
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
s_kpad
.
back
(),
seed
);
auto
rem_k
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
s_kpad
.
back
(),
seed
);
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
1
,
s_kpad
.
back
(),
seed
);
auto
rem_k
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
seqlen_k_min
,
s_kpad
.
back
(),
seed
);
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
...
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
r
=
std
::
atoi
(
v
);
return
r
;
}
// Fill [first, last) with consecutive integers starting at 'value', then
// randomly permute them. A fixed 'seed' makes the permutation deterministic.
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
                                                       RandomAccessIterator last,
                                                       Int value,
                                                       std::optional<unsigned> seed = std::nullopt)
{
    for(auto it = first; it != last; ++it)
    {
        *it = value++;
    }
    std::mt19937 rng(seed.has_value() ? *seed : std::random_device{}());
    std::shuffle(first, last, rng);
}
include/ck/ck.hpp
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set
stochastic rounding
as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION
1
// set
rounding to nearest even
as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION
0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
View file @
4885c38a
...
...
@@ -67,8 +67,8 @@ struct BlockwiseGemmXdlops_pipeline_base
KPerBlock
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
A
_K1
,
B
_K1
,
A
BlockTransferSrcScalarPerVector
,
B
BlockTransferSrcScalarPerVector
,
A_K1
,
B_K1
,
MRepeat
,
...
...
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
View file @
4885c38a
...
...
@@ -3,7 +3,6 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
...
...
@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
UnaryOp2
unary_op2_
{};
};
using
ScaleScalePass
=
UnaryCombinedOp
<
Scale
,
Scale
,
PassThrough
>
;
using
ScaleScaleRelu
=
UnaryCombinedOp
<
Scale
,
Scale
,
Relu
>
;
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
4885c38a
...
...
@@ -752,11 +752,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
if
constexpr
(
BBlockLdsExtraN
)
// if constexpr(BBlockLdsExtraN)
// {
// return make_naive_tensor_descriptor(
// make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
// make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
// }
// else
if
constexpr
(
BBlockLdsExtraN
&&
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
BK1Number
,
Number
<
K
PerBlock
+
BBlockLdsExtraN
>
{},
I1
));
make_tuple
(
BK1Number
*
Number
<
N
PerBlock
>
{},
I1
,
Number
<
NPerBlock
>
{}
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
...
...
@@ -1318,9 +1325,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
...
...
include/ck_tile/core/config.hpp
View file @
4885c38a
...
...
@@ -46,6 +46,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...
...
@@ -156,6 +157,14 @@
#endif
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
4885c38a
...
...
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
standard
=
0
,
// rtn
truncate_with_nan
,
truncate
,
standard_asm
,
};
template
<
bf16_rounding_mode
rounding
=
...
...
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
return
uint16_t
(
u
.
int32
>>
16
);
}
// Host-side counterpart of the asm bf16 conversion: GPU inline assembly is
// unavailable on the host, so fall back to the plain round-to-nearest
// implementation.
CK_TILE_HOST constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
// Device-side float -> bf16 conversion written in GCN inline assembly.
// Operand map (from the constraint list below):
//   %0 check_nan : result of v_cmp_u_f32 (unordered self-compare, i.e. NaN test)
//   %1 tmp       : scratch register for the rounding-bias sum
//   %2 u.fp32    : input bits, overwritten with the shifted result
//   %3 / %4      : ROUND_BIAS_FOR_BF16 / FP32_NAN constants
CK_TILE_DEVICE uint16_t float_to_bf16_rtn_asm(float f)
{
    // reinterpret the float's bits as a 32-bit integer
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};
    // canonical NaN bit pattern selected when the input is NaN
    static constexpr uint32_t FP32_NAN = 0x7fff0000;
    // 0x7fff plus the kept LSB (extracted via v_bfe_u32) implements
    // round-to-nearest-even on the truncated upper 16 bits
    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan; // sgpr pair receiving the compare mask ("=s")
    uint32_t tmp;
    asm volatile("\n \
    v_cmp_u_f32 %0, %2, %2 \n \
    v_bfe_u32 %1, %2, 16, 1 \n \
    v_add3_u32 %1, %2, %1, %3 \n \
    v_cndmask_b32 %2, %1, %4, %0 \n \
    v_lshrrev_b32 %2, 16, %2 \n \
    "
                 : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
    return uint16_t(u.int32);
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr
uint16_t
float_to_bf16_truc_nan_raw
(
float
f
)
...
...
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
{
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard
)
return
float_to_bf16_rtn_raw
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard_asm
)
return
float_to_bf16_rtn_asm
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
truncate_with_nan
)
return
float_to_bf16_truc_nan_raw
(
f
);
else
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment