Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4885c38a
Commit
4885c38a
authored
Sep 03, 2024
by
aska-0096
Browse files
Merge branch 'transpose_opt' of
https://github.com/ROCm/composable_kernel
into rowwise_opt
parents
cbf14ee1
7c8e92fa
Changes
83
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1886 additions
and
343 deletions
+1886
-343
example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
..._convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
+98
-0
example/ck_tile/01_fmha/CMakeLists.txt
example/ck_tile/01_fmha/CMakeLists.txt
+37
-3
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+13
-1
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
+355
-0
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
+127
-61
example/ck_tile/01_fmha/fmha_fwd.cpp
example/ck_tile/01_fmha/fmha_fwd.cpp
+604
-164
example/ck_tile/01_fmha/fmha_fwd.hpp
example/ck_tile/01_fmha/fmha_fwd.hpp
+293
-30
example/ck_tile/01_fmha/generate.py
example/ck_tile/01_fmha/generate.py
+16
-11
example/ck_tile/01_fmha/rotary.hpp
example/ck_tile/01_fmha/rotary.hpp
+84
-0
example/ck_tile/01_fmha/script/benchmark_bwd.sh
example/ck_tile/01_fmha/script/benchmark_bwd.sh
+2
-3
example/ck_tile/01_fmha/script/benchmark_fwd.sh
example/ck_tile/01_fmha/script/benchmark_fwd.sh
+2
-3
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
+2
-3
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
+94
-41
example/ck_tile/01_fmha/utils.hpp
example/ck_tile/01_fmha/utils.hpp
+97
-13
include/ck/ck.hpp
include/ck/ck.hpp
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
...eration/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
+2
-2
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
...operation/gpu/element/combined_element_wise_operation.hpp
+3
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+11
-4
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+9
-0
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+34
-0
No files found.
example/62_convnd_activ/convscale_reduce/run_convnd_fwd_example.inc
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool
run_convnd_fwd_example
(
int
argc
,
char
*
argv
[])
{
print_helper_msg
();
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
ck
::
utils
::
conv
::
ConvParam
conv_param
{
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}};
if
(
argc
==
1
)
{
// use default
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_param
=
ck
::
utils
::
conv
::
parse_conv_param
(
num_dim_spatial
,
5
,
argv
);
}
// instantiate in and wei element ops, will
// instantiate out_element_op below for every iteration
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
run
=
[
&
](
auto
ndim_spatial
,
auto
in_layout
,
auto
wei_layout
,
auto
out_layout
)
{
constexpr
ck
::
index_t
ndim_spatial_value
=
ndim_spatial
.
value
;
using
InLayout
=
decltype
(
in_layout
);
using
WeiLayout
=
decltype
(
wei_layout
);
using
OutLayout
=
decltype
(
out_layout
);
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
const
auto
wei_g_k_c_xs_desc
=
ck
::
utils
::
conv
::
make_weight_host_tensor_descriptor_g_k_c_xs_packed
<
WeiLayout
>
(
conv_param
);
const
auto
out_g_n_k_wos_desc
=
ck
::
utils
::
conv
::
make_output_host_tensor_descriptor_g_n_k_wos_packed
<
OutLayout
>
(
conv_param
);
return
run_grouped_conv_fwd
<
ndim_spatial_value
,
InDataType
,
WeiDataType
,
ConvOutDataType
,
OutDataType
,
InElementOp
,
WeiElementOp
,
OutElementOp
,
DeviceGroupedConvNDFwdInstance
<
ndim_spatial_value
,
InLayout
,
WeiLayout
,
OutLayout
>>
(
do_verification
,
init_method
,
time_kernel
,
conv_param
,
in_g_n_c_wis_desc
,
wei_g_k_c_xs_desc
,
out_g_n_k_wos_desc
,
in_element_op
,
wei_element_op
);
};
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
if
(
conv_param
.
num_dim_spatial_
==
1
)
{
return
run
(
ck
::
Number
<
1
>
{},
ctc
::
GNWC
{},
ctc
::
GKXC
{},
ctc
::
GNWK
{});
}
else
if
(
conv_param
.
num_dim_spatial_
==
2
)
{
return
run
(
ck
::
Number
<
2
>
{},
ctc
::
GNHWC
{},
ctc
::
GKYXC
{},
ctc
::
GNHWK
{});
}
else
if
(
conv_param
.
num_dim_spatial_
==
3
)
{
return
run
(
ck
::
Number
<
3
>
{},
ctc
::
GNDHWC
{},
ctc
::
GKZYXC
{},
ctc
::
GNDHWK
{});
}
return
true
;
}
example/ck_tile/01_fmha/CMakeLists.txt
View file @
4885c38a
# generate a list of kernels, but not actually emit files at config stage
# validate user-specified fmha_fwd API list
set
(
FMHA_FWD_KNOWN_APIS
"fwd;fwd_splitkv;fwd_appendkv"
)
set
(
FMHA_FWD_ENABLE_APIS
"fwd"
CACHE STRING
"semicolon-separated list of APIs to generate (
${
FMHA_FWD_KNOWN_APIS
}
) & link, or
\"
all
\"
."
)
if
(
FMHA_FWD_ENABLE_APIS STREQUAL
"all"
)
set
(
FMHA_FWD_ENABLE_APIS
${
FMHA_FWD_KNOWN_APIS
}
)
endif
()
foreach
(
api
${
FMHA_FWD_ENABLE_APIS
}
)
if
(
NOT
"
${
api
}
"
IN_LIST FMHA_FWD_KNOWN_APIS
)
message
(
FATAL_ERROR
"
${
api
}
isn't a known api:
${
FMHA_FWD_KNOWN_APIS
}
."
)
endif
()
endforeach
()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if
(
NOT
"fwd"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND FMHA_FWD_ENABLE_APIS
"fwd"
)
endif
()
string
(
REPLACE
";"
","
FMHA_FWD_APIS
"
${
FMHA_FWD_ENABLE_APIS
}
"
)
# generate a list of kernels, but not actually emit files at config sta
execute_process
(
execute_process
(
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
fwd,fwd_splitkv
--list_blobs
${
CMAKE_CURRENT_BINARY_DIR
}
/fwd_blob_list.txt
--api
${
FMHA_FWD_APIS
}
--list_blobs
${
CMAKE_CURRENT_BINARY_DIR
}
/fwd_blob_list.txt
)
)
execute_process
(
execute_process
(
...
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
...
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command
(
add_custom_command
(
OUTPUT
${
FMHA_FWD_GEN_BLOBS
}
OUTPUT
${
FMHA_FWD_GEN_BLOBS
}
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
COMMAND
${
Python3_EXECUTABLE
}
${
CMAKE_CURRENT_LIST_DIR
}
/generate.py
--api
fwd,fwd_splitkv
--output_dir
${
CMAKE_CURRENT_BINARY_DIR
}
--api
${
FMHA_FWD_APIS
}
--output_dir
${
CMAKE_CURRENT_BINARY_DIR
}
)
)
add_custom_command
(
add_custom_command
(
...
@@ -60,6 +80,20 @@ else()
...
@@ -60,6 +80,20 @@ else()
endif
()
endif
()
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero
)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
if
(
"fwd_splitkv"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1
)
else
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0
)
endif
()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example
if
(
"fwd_appendkv"
IN_LIST FMHA_FWD_ENABLE_APIS
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1
)
else
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0
)
endif
()
# Allow comparing floating points directly in order to check sentinel values
# Allow comparing floating points directly in order to check sentinel values
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal
)
...
...
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
View file @
4885c38a
...
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
...
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
"dropout_wg16_storerandval"
:
"t.has_dropout == true && t.is_store_randval == true"
,
"dropout_wg16_storerandval"
:
"t.has_dropout == true && t.is_store_randval == true"
,
}
}
ROPE_MAP
=
{
"no"
:
"ck_tile::RotaryEmbeddingEnum::NONE"
,
"inter"
:
"ck_tile::RotaryEmbeddingEnum::INTERLEAVED"
,
"half"
:
"ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}
ROPE_CHECK_MAP
=
{
"no"
:
"rope_enum::none"
,
"inter"
:
"rope_enum::interleaved"
,
"half"
:
"rope_enum::half_rotated"
}
MODE_MAP
=
{
MODE_MAP
=
{
"batch"
:
"false"
,
"batch"
:
"false"
,
"group"
:
"true"
"group"
:
"true"
...
@@ -105,4 +117,4 @@ PIPELINE_ENUM_MAP = {
...
@@ -105,4 +117,4 @@ PIPELINE_ENUM_MAP = {
BOOL_MAP
=
{
BOOL_MAP
=
{
"t"
:
"true"
,
"t"
:
"true"
,
"f"
:
"false"
"f"
:
"false"
}
}
\ No newline at end of file
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py
0 → 100644
View file @
4885c38a
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import
copy
from
dataclasses
import
dataclass
import
fnmatch
import
itertools
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Tuple
from
codegen.cmake_config
import
*
from
codegen.cpp_symbol_map
import
*
from
codegen.ops.fmha_fwd
import
(
FmhaFwdApiTrait
,
DTYPE_BITS
,
FMHA_FWD_KERNEL_HEADER
,
FMHA_FWD_API_PER_DTYPE
,
FMHA_FWD_API_PER_HDIM_CASE
,
)
FMHA_FWD_APPENDKV_KERNEL_BODY
=
"""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
{F_bs},
{F_bsk},
{F_bd},
{F_bdv},
{F_vlayout},
{F_rope},
{F_pagedkv},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME
=
"fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API
=
"""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@
dataclass
class
FmhaFwdAppendKVApiTrait
:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim
:
str
dtype
:
str
# data type
bs
:
int
# tile size along q seqlen
bsk
:
int
# tile size along k seqlen
bd
:
int
# tile size along qk gemm unroll
bdv
:
int
# tile size along kv gemm unroll
vlayout
:
str
spad
:
str
skpad
:
str
dpad
:
str
dvpad
:
str
rope
:
str
# key from ROPE_MAP
pagedkv
:
str
@
property
def
name
(
self
)
->
str
:
return
f
'
{
self
.
hdim
}
-
{
self
.
dtype
}
-
{
self
.
bs
}
-
{
self
.
bsk
}
-
{
self
.
bd
}
-
{
self
.
bdv
}
-
{
self
.
vlayout
}
-'
+
\
f
'
{
self
.
spad
}
-
{
self
.
skpad
}
-
{
self
.
dpad
}
-
{
self
.
dvpad
}
-
{
self
.
rope
}
-
{
self
.
pagedkv
}
'
@
property
def
scheck
(
self
)
->
str
:
if
self
.
spad
==
't'
:
return
f
'true /*a.seqlen_q %
{
self
.
bs
}
!= 0*/'
else
:
return
f
'a.seqlen_q %
{
self
.
bs
}
== 0'
@
property
def
skcheck
(
self
)
->
str
:
# we do not check all the values in a.seqlen_k_ptr
return
'true'
@
property
def
dcheck
(
self
)
->
str
:
if
self
.
dpad
==
't'
:
return
f
'true /*a.hdim_q %
{
self
.
bd
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_q %
{
self
.
bd
}
== 0'
@
property
def
dvcheck
(
self
)
->
str
:
if
self
.
dvpad
==
't'
:
return
f
'true /*a.hdim_v %
{
self
.
bdv
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_v %
{
self
.
bdv
}
== 0'
@
dataclass
class
FmhaFwdAppendKVPipeline
:
F_vlayout
:
str
# row/col
F_spad
:
str
# true/false
F_skpad
:
str
#
F_dpad
:
str
#
F_dvpad
:
str
#
F_rope
:
str
# key from ROPE_MAP
F_pagedkv
:
str
# t/f
@
property
def
name
(
self
)
->
str
:
def
pad_name
()
->
str
:
n
=
''
if
self
.
F_spad
==
't'
:
n
+=
's'
if
self
.
F_skpad
==
't'
:
n
+=
'sk'
if
self
.
F_dpad
==
't'
:
n
+=
'd'
if
self
.
F_dvpad
==
't'
:
n
+=
'dv'
if
n
!=
''
:
n
=
'p'
+
n
return
n
pn
=
pad_name
()
n
=
f
'v
{
self
.
F_vlayout
[
0
]
}
'
if
pn
!=
''
:
n
+=
f
'_
{
pn
}
'
if
self
.
F_rope
!=
'no'
:
n
+=
f
'_
{
self
.
F_rope
}
'
if
self
.
F_pagedkv
==
't'
:
n
+=
'_pagedkv'
return
n
class
FmhaFwdAppendKVApiPool
:
def
__init__
(
self
,
mask_impl
):
self
.
pool
=
dict
()
self
.
mask_impl
=
mask_impl
def
register_traits
(
self
,
trait
:
FmhaFwdApiTrait
)
->
None
:
# TODO: do we need to check duplication?
if
trait
.
dtype
not
in
self
.
pool
.
keys
():
self
.
pool
[
trait
.
dtype
]
=
dict
()
if
trait
.
hdim
not
in
self
.
pool
[
trait
.
dtype
].
keys
():
self
.
pool
[
trait
.
dtype
][
trait
.
hdim
]
=
list
()
self
.
pool
[
trait
.
dtype
][
trait
.
hdim
].
append
(
copy
.
copy
(
trait
))
@
property
def
api
(
self
)
->
str
:
per_dtypes
=
str
()
for
i
,
dtype
in
enumerate
(
self
.
pool
.
keys
()):
per_hdim_case
=
str
()
for
j
,
hdim
in
enumerate
(
self
.
pool
[
dtype
].
keys
()):
traits
=
self
.
pool
[
dtype
][
hdim
]
inners
=
str
()
for
k
,
trait
in
enumerate
(
traits
):
if_k
=
'if'
if
k
==
0
else
'else if'
inners
=
inners
+
FMHA_FWD_APPENDKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_rope_check
=
ROPE_CHECK_MAP
[
trait
.
rope
],
F_pagedkv
=
BOOL_MAP
[
trait
.
pagedkv
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_rope
=
ROPE_MAP
[
trait
.
rope
],
F_bs
=
trait
.
bs
,
F_bsk
=
trait
.
bsk
,
F_bd
=
trait
.
bd
,
F_bdv
=
trait
.
bdv
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
if_j
=
'if'
if
j
==
0
else
'else if'
per_hdim_case
=
per_hdim_case
+
FMHA_FWD_API_PER_HDIM_CASE
.
format
(
F_if
=
if_j
,
F_hdim
=
hdim
,
F_inner_dispatch
=
inners
)
if_i
=
'if'
if
i
==
0
else
'else if'
per_dtypes
=
per_dtypes
+
FMHA_FWD_API_PER_DTYPE
.
format
(
F_if
=
if_i
,
F_dtype
=
dtype
,
F_hdim_case
=
per_hdim_case
)
return
FMHA_FWD_KERNEL_HEADER
+
FMHA_FWD_APPENDKV_API
.
format
(
F_dispatch
=
per_dtypes
)
@
dataclass
class
FmhaFwdAppendKVTileSize
:
F_bs
:
int
# tile size along q seqlen
F_bsk
:
int
# tile size along k seqlen
F_bd
:
int
# tile size along qk gemm unroll
F_bdv
:
int
# tile size along kv gemm unroll
F_occupancy
:
int
# occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@
property
def
name
(
self
)
->
str
:
return
f
"b
{
self
.
F_bs
}
x
{
self
.
F_bsk
}
x
{
self
.
F_bd
}
x
{
self
.
F_bdv
}
"
+
\
(
""
if
self
.
F_occupancy
==
-
1
else
f
"_o
{
self
.
F_occupancy
}
"
)
@
dataclass
class
FmhaFwdAppendKVKernel
:
F_idx
:
int
# this is not a tunable, but a counter to differentiate symbol
F_hdim
:
int
# hdim
F_dtype
:
str
# data type
F_tile
:
FmhaFwdAppendKVTileSize
F_pipeline
:
FmhaFwdAppendKVPipeline
mask_impl
:
str
@
property
def
template
(
self
)
->
str
:
kernel_body
=
str
()
return
FMHA_FWD_KERNEL_HEADER
+
\
FMHA_FWD_APPENDKV_KERNEL_BODY
.
format
(
F_idx
=
self
.
F_idx
,
F_hdim
=
self
.
F_hdim
,
F_dtype
=
DTYPE_MAP
[
self
.
F_dtype
],
F_bs
=
self
.
F_tile
.
F_bs
,
F_bsk
=
self
.
F_tile
.
F_bsk
,
F_bd
=
self
.
F_tile
.
F_bd
,
F_bdv
=
self
.
F_tile
.
F_bdv
,
F_vlayout
=
LAYOUT_MAP
[
self
.
F_pipeline
.
F_vlayout
],
F_spad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_spad
],
F_skpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_skpad
],
F_dpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_rope
=
ROPE_MAP
[
self
.
F_pipeline
.
F_rope
],
F_pagedkv
=
BOOL_MAP
[
self
.
F_pipeline
.
F_pagedkv
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
)
@
property
def
name
(
self
)
->
str
:
# TODO: we don't encode idx here
return
f
"fmha_fwd_appendkv_d
{
self
.
F_hdim
}
_
{
self
.
F_dtype
}
_"
+
\
self
.
F_tile
.
name
+
'_'
+
self
.
F_pipeline
.
name
@
property
def
filename
(
self
)
->
str
:
return
self
.
name
+
".cpp"
def
api_trait
(
self
)
->
FmhaFwdAppendKVApiTrait
:
return
FmhaFwdAppendKVApiTrait
(
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
bs
=
self
.
F_tile
.
F_bs
,
bsk
=
self
.
F_tile
.
F_bsk
,
bd
=
self
.
F_tile
.
F_bd
,
bdv
=
self
.
F_tile
.
F_bdv
,
vlayout
=
self
.
F_pipeline
.
F_vlayout
,
spad
=
self
.
F_pipeline
.
F_spad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
dvpad
=
self
.
F_pipeline
.
F_dvpad
,
rope
=
self
.
F_pipeline
.
F_rope
,
pagedkv
=
self
.
F_pipeline
.
F_pagedkv
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def
get_fmha_fwd_appendkv_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
if
dtype
==
'fp16'
or
dtype
==
'bf16'
:
return
{
'32'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
32
,
32
,
-
1
),
'64'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
64
,
64
,
-
1
),
'128'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
128
,
128
,
-
1
),
'256'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
256
,
256
,
-
1
),
}
elif
dtype
==
'fp8'
or
dtype
==
'bf8'
:
return
{
'64'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
64
,
64
,
-
1
),
'128'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
128
,
128
,
-
1
),
'256'
:
FmhaFwdAppendKVTileSize
(
64
,
64
,
256
,
256
,
-
1
)
}
else
:
return
None
def
get_fwd_appendkv_blobs
(
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
Tuple
[
FmhaFwdAppendKVApiPool
,
List
[
FmhaFwdAppendKVKernel
]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def
get_pipelines
(
dtype
,
hdim
)
->
List
[
FmhaFwdAppendKVPipeline
]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for
vlayout
in
[
'row'
,
'col'
]:
for
pagedkv
in
[
"t"
,
"f"
]:
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
'f'
,
't'
,
'f'
,
'f'
,
'no'
,
pagedkv
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
't'
,
't'
,
't'
,
't'
,
'no'
,
pagedkv
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
'f'
,
't'
,
't'
,
'f'
,
'inter'
,
pagedkv
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
't'
,
't'
,
't'
,
't'
,
'inter'
,
pagedkv
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
'f'
,
't'
,
't'
,
'f'
,
'half'
,
pagedkv
))
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
vlayout
,
't'
,
't'
,
't'
,
't'
,
'half'
,
pagedkv
))
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# rope/paged-kv is not supported
pipelines
.
append
(
FmhaFwdAppendKVPipeline
(
'col'
,
't'
,
't'
,
't'
,
't'
,
'no'
,
'f'
))
else
:
assert
False
return
pipelines
gen
=
list
()
api_pool
=
FmhaFwdAppendKVApiPool
(
mask_impl
)
for
dtype
in
DTYPE_MAP
.
keys
():
d
=
get_fmha_fwd_appendkv_tile_dict_from_dtype
(
dtype
)
if
d
==
None
:
continue
for
hdim_str
in
d
.
keys
():
tile
=
d
[
hdim_str
]
hdim
=
int
(
hdim_str
)
for
pipeline
in
get_pipelines
(
dtype
,
hdim
):
k
=
FmhaFwdAppendKVKernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_tile
=
tile
,
F_pipeline
=
pipeline
,
mask_impl
=
mask_impl
)
if
kernel_filter
!=
None
:
if
not
fnmatch
.
fnmatch
(
k
.
name
,
kernel_filter
):
continue
if
receipt
==
2
:
cond
=
dtype
in
[
'fp16'
,
'bf16'
]
cond
&=
pipeline
.
F_vlayout
==
'row'
if
not
cond
:
continue
api_pool
.
register_traits
(
k
.
api_trait
())
gen
.
append
(
k
)
return
(
api_pool
,
gen
)
def
write_single_kernel
(
kernel
:
FmhaFwdAppendKVKernel
,
autogen_dir
:
Path
)
->
None
:
(
autogen_dir
/
kernel
.
filename
).
write_text
(
kernel
.
template
)
def
write_fwd_appendkv_api
(
api_pool
:
FmhaFwdAppendKVApiPool
,
autogen_dir
:
Path
)
->
None
:
(
autogen_dir
/
FMHA_FWD_APPENDKV_API_FILENAME
).
write_text
(
api_pool
.
api
)
def
write_blobs
(
output_dir
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
api_pool
,
kernels
=
get_fwd_appendkv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
write_single_kernel
(
kernel
,
output_dir
)
write_fwd_appendkv_api
(
api_pool
,
output_dir
)
def
list_blobs
(
file_path
:
Path
,
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
with
file_path
.
open
(
'a'
)
as
f
:
_
,
kernels
=
get_fwd_appendkv_blobs
(
kernel_filter
,
receipt
,
mask_impl
)
for
kernel
in
kernels
:
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
kernel
.
filename
)
+
"
\n
"
)
f
.
write
(
str
(
file_path
.
parent
/
GEN_DIR
/
FMHA_FWD_APPENDKV_API_FILENAME
)
+
"
\n
"
)
\ No newline at end of file
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
View file @
4885c38a
...
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
...
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
)
)
DTYPE_BITS
=
{
"fp32"
:
32
,
"fp16"
:
16
,
"bf16"
:
16
,
"fp8"
:
8
,
"bf8"
:
8
}
FMHA_FWD_SPLITKV_PIPELINE_MAP
=
{
FMHA_FWD_SPLITKV_PIPELINE_MAP
=
{
"qr"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS"
,
"qr"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS"
,
"qr_async"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync"
,
"qr_async"
:
"ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync"
,
...
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
...
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
{F_bias},
{F_bias},
false,
false,
{F_lse},
{F_lse},
{F_dropout},
{F_squant},
{F_squant},
{F_pagedkv},
kHasUnevenSplits,
kHasUnevenSplits,
{F_occupancy}>;
{F_occupancy}>;
...
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
...
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
...
@@ -86,7 +93,7 @@ using fmha_kernel =
...
@@ -86,7 +93,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_pipeline,
fmha_epilogue>;
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
{{
using k_ = fmha_kernel;
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
...
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
...
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
}};
}};
}}
}}
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
#include <iostream>
template<>
template<>
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
{{
if constexpr({F_mode} == false) {{ // batch mode
if constexpr({F_mode} == false) {{ // batch mode
if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
// we don't check every seqlen_k values for kvcache
if (a.seqlen_k_ptr != nullptr) {{
kernel_runner<true>::run(s, a);
// make sure F_bn0 is divisible by F_bk1
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
kernel_runner<false>::run(s, a);
kernel_runner<false>::run(s, a);
}} else {{
}} else {{
kernel_runner<true>::run(s, a);
kernel_runner<true>::run(s, a);
...
@@ -160,7 +172,7 @@ using fmha_kernel =
...
@@ -160,7 +172,7 @@ using fmha_kernel =
fmha_pipeline,
fmha_pipeline,
fmha_epilogue>;
fmha_epilogue>;
static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
static void run(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
{{
using k_ = fmha_kernel;
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
...
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
...
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
#include <iostream>
#include <iostream>
template<>
template<>
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
{{
if (a.num_splits <= 16) {{
if (a.num_splits <= 16) {{
kernel_runner<4>::run(s, a);
kernel_runner<4>::run(s, a);
...
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
...
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
#include <iostream>
#include <iostream>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_
splitkv_
args a)
{{
{{
if(s.log_level_ > 0)
if(s.log_level_ > 0)
std::cout
std::cout
...
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
...
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
);
);
}}
}}
float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float fmha_fwd_splitkv(fmha_fwd_
splitkv_
traits t, fmha_fwd_
splitkv_
args a, const ck_tile::stream_config& s){{
float r = -1;
float r = -1;
{F_dispatch}
{F_dispatch}
return r;
return r;
}}
}}
"""
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse})
&& (t.has_dropout == {F_dropout})
&& (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
=
""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
((a.block_table_ptr != nullptr) == {F_pagedkv}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_
dropou
t}, {F_
squant
}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits_ = fmha_fwd_
splitkv_
traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_
squan
t}, {F_
pagedkv
}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
}}
}}
"""
"""
@
dataclass
class
FmhaFwdSplitKVApiTrait
:
pipeline_tag
:
str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim
:
str
dtype
:
str
# data type
mode
:
str
# value from MODE_MAP
bm0
:
int
# tile size along q seqlen (block size)
bn0
:
int
# tile size along qk seqlen
bk0
:
int
# tile size along qk gemm unroll
bn1
:
int
# tile size along v head_dim
bk1
:
int
# tile size along kv gemm unroll
bk0blen
:
int
vlayout
:
str
mask
:
str
bias
:
str
#
lse
:
str
#
squant
:
str
#
spad
:
str
skpad
:
str
dpad
:
str
dvpad
:
str
pagedkv
:
str
@
property
def
name
(
self
)
->
str
:
return
f
'
{
self
.
hdim
}
-
{
self
.
dtype
}
-
{
self
.
mode
}
-
{
self
.
bm0
}
-
{
self
.
bn0
}
-
{
self
.
bk0
}
-
{
self
.
bn0
}
-
{
self
.
bk1
}
-
{
self
.
bk0blen
}
-'
+
\
f
'
{
self
.
vlayout
}
-
{
self
.
mask
}
-
{
self
.
bias
}
-
{
self
.
lse
}
-
{
self
.
squant
}
-
{
self
.
spad
}
-
{
self
.
skpad
}
-
{
self
.
dpad
}
-'
+
\
f
'
{
self
.
dvpad
}
-
{
self
.
pagedkv
}
'
@
property
def
scheck
(
self
)
->
str
:
if
self
.
mode
==
'group'
:
return
'true/*group mode spad always true*/'
# group mode only generate spad/skpad == true
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
spad
==
't'
:
return
'true'
# always support
else
:
return
'true'
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
spad
==
't'
:
return
f
'true /*a.seqlen_q %
{
self
.
bm0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_q %
{
self
.
bm0
}
== 0'
else
:
assert
False
@
property
def
skcheck
(
self
)
->
str
:
if
self
.
mode
==
'group'
:
return
'true/*group mode skpad always true*/'
# group mode only generate spad/skpad == true
if
self
.
pipeline_tag
==
'qr_async'
:
if
self
.
skpad
==
't'
:
return
f
'a.seqlen_k == 0 || a.seqlen_k %
{
self
.
bn0
}
!= 0'
else
:
return
f
'a.seqlen_k != 0 && a.seqlen_k %
{
self
.
bn0
}
== 0'
elif
self
.
pipeline_tag
in
[
'qr'
,
'qr_fp8'
]:
if
self
.
skpad
==
't'
:
return
f
'true /*a.seqlen_k %
{
self
.
bn0
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.seqlen_k %
{
self
.
bn0
}
== 0'
else
:
assert
False
@
property
def
dcheck
(
self
)
->
str
:
if
self
.
pipeline_tag
==
'qr_async'
:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dpad
==
't'
:
return
f
'a.hdim_q %
{
vec
}
== 0'
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
dpad
==
't'
:
return
f
'true /*a.hdim_q %
{
self
.
bk0blen
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_q %
{
self
.
bk0blen
}
== 0'
else
:
assert
False
@
property
def
dvcheck
(
self
)
->
str
:
if
self
.
pipeline_tag
==
'qr_async'
:
vec
=
int
((
32
*
4
)
/
DTYPE_BITS
[
self
.
dtype
])
if
self
.
dvpad
==
't'
:
return
f
'a.hdim_v %
{
vec
}
== 0'
else
:
assert
False
elif
self
.
pipeline_tag
in
[
'qr'
]:
if
self
.
dvpad
==
't'
:
return
f
'true /*a.hdim_v %
{
self
.
bk0blen
}
!= 0*/'
# TODO: order of get_pipelines() matters! (ugly)
else
:
return
f
'a.hdim_v %
{
self
.
bk0blen
}
== 0'
else
:
assert
False
@
dataclass
@
dataclass
class
FmhaFwdSplitKVPipeline
:
class
FmhaFwdSplitKVPipeline
:
tag
:
str
tag
:
str
...
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
...
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
F_dvpad
:
str
#
F_dvpad
:
str
#
F_bias
:
str
# true/false
F_bias
:
str
# true/false
F_lse
:
str
#
F_lse
:
str
#
F_dropout
:
str
#
F_squant
:
str
#
F_squant
:
str
#
F_pagedkv
:
str
# t/f
F_mask
:
str
# value from MASK_MAP
F_mask
:
str
# value from MASK_MAP
@
property
@
property
...
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
...
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
else
:
else
:
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_mask
!=
'no'
:
n
+=
f
'_m
{
self
.
F_mask
[
0
]
}
'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_lse
==
't'
:
n
+=
'_lse'
if
self
.
F_dropout
==
't'
:
n
+=
'_dropout'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
if
self
.
F_squant
==
't'
:
n
+=
'_squant'
if
self
.
F_pagedkv
==
't'
:
n
+=
'_pagedkv'
return
n
return
n
@
dataclass
@
dataclass
...
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
...
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
self
.
pool
=
dict
()
self
.
pool
=
dict
()
self
.
mask_impl
=
mask_impl
self
.
mask_impl
=
mask_impl
def
register_traits
(
self
,
trait
:
FmhaFwdApiTrait
)
->
None
:
def
register_traits
(
self
,
trait
:
FmhaFwd
SplitKV
ApiTrait
)
->
None
:
# TODO: do we need to check duplication?
# TODO: do we need to check duplication?
if
trait
.
dtype
not
in
self
.
pool
.
keys
():
if
trait
.
dtype
not
in
self
.
pool
.
keys
():
self
.
pool
[
trait
.
dtype
]
=
dict
()
self
.
pool
[
trait
.
dtype
]
=
dict
()
...
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
...
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
inners
=
inners
+
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
inners
=
inners
+
FMHA_FWD_SPLITKV_API_INNER_DISPATCH
.
format
(
F_if
=
if_k
,
F_mode
=
MODE_MAP
[
trait
.
mode
],
F_vlayout
=
LAYOUT_MAP
[
trait
.
vlayout
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
trait
.
pipeline_tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_mask_check
=
get_mask_check_map
(
self
.
mask_impl
)[
trait
.
mask
],
F_bias_check
=
BIAS_CHECK_MAP
[
trait
.
bias
],
F_bias
=
BIAS_MAP
[
trait
.
bias
],
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_
dropou
t
=
BOOL_MAP
[
trait
.
dropout
]
,
F_lse
=
BOOL_MAP
[
trait
.
lse
],
F_
squan
t
=
BOOL_MAP
[
trait
.
squant
],
F_pagedkv
=
BOOL_MAP
[
trait
.
pagedkv
],
F_squant
=
BOOL_MAP
[
trait
.
squant
],
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_scheck
=
trait
.
scheck
,
F_skcheck
=
trait
.
skcheck
,
F_dcheck
=
trait
.
dcheck
,
F_dvcheck
=
trait
.
dvcheck
,
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_spad
=
BOOL_MAP
[
trait
.
spad
],
F_skpad
=
BOOL_MAP
[
trait
.
skpad
],
F_dpad
=
BOOL_MAP
[
trait
.
dpad
],
F_dvpad
=
BOOL_MAP
[
trait
.
dvpad
],
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
F_bm0
=
trait
.
bm0
,
F_bn0
=
trait
.
bn0
,
F_bk0
=
trait
.
bk0
,
F_bn1
=
trait
.
bn1
,
F_bk1
=
trait
.
bk1
,
F_bk0blen
=
trait
.
bk0blen
,
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
F_hdim
=
hdim
,
F_dtype
=
DTYPE_MAP
[
dtype
])
...
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
...
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_dvpad
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dvpad
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_bias
=
BIAS_MAP
[
self
.
F_pipeline
.
F_bias
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_lse
=
BOOL_MAP
[
self
.
F_pipeline
.
F_lse
],
F_dropout
=
BOOL_MAP
[
self
.
F_pipeline
.
F_dropout
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_squant
=
BOOL_MAP
[
self
.
F_pipeline
.
F_squant
],
F_pagedkv
=
BOOL_MAP
[
self
.
F_pipeline
.
F_pagedkv
],
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_occupancy
=
self
.
F_tile
.
F_occupancy
,
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_pipeline_enum
=
PIPELINE_ENUM_MAP
[
self
.
F_pipeline
.
tag
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_pipeline
.
F_mask
],
F_mask
=
get_mask_map
(
self
.
mask_impl
)[
self
.
F_pipeline
.
F_mask
],
...
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
...
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
def
filename
(
self
)
->
str
:
def
filename
(
self
)
->
str
:
return
self
.
name
+
".cpp"
return
self
.
name
+
".cpp"
def
api_trait
(
self
)
->
FmhaFwdApiTrait
:
def
api_trait
(
self
)
->
FmhaFwd
SplitKV
ApiTrait
:
return
FmhaFwdApiTrait
(
return
FmhaFwd
SplitKV
ApiTrait
(
pipeline_tag
=
self
.
F_pipeline
.
tag
,
pipeline_tag
=
self
.
F_pipeline
.
tag
,
hdim
=
str
(
self
.
F_hdim
),
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
dtype
=
self
.
F_dtype
,
...
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
...
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
mask
=
self
.
F_pipeline
.
F_mask
,
mask
=
self
.
F_pipeline
.
F_mask
,
bias
=
self
.
F_pipeline
.
F_bias
,
bias
=
self
.
F_pipeline
.
F_bias
,
lse
=
self
.
F_pipeline
.
F_lse
,
lse
=
self
.
F_pipeline
.
F_lse
,
dropout
=
self
.
F_pipeline
.
F_dropout
,
squant
=
self
.
F_pipeline
.
F_squant
,
squant
=
self
.
F_pipeline
.
F_squant
,
pagedkv
=
self
.
F_pipeline
.
F_pagedkv
,
spad
=
self
.
F_pipeline
.
F_spad
,
spad
=
self
.
F_pipeline
.
F_spad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
...
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
...
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
def
filename
(
self
)
->
str
:
def
filename
(
self
)
->
str
:
return
self
.
name
+
".cpp"
return
self
.
name
+
".cpp"
def
api_trait
(
self
)
->
FmhaFwdApiTrait
:
return
FmhaFwdApiTrait
(
pipeline_tag
=
self
.
F_pipeline
.
tag
,
hdim
=
str
(
self
.
F_hdim
),
dtype
=
self
.
F_dtype
,
mode
=
self
.
F_mode
,
bm0
=
self
.
F_tile
.
F_bm0
,
bn0
=
self
.
F_tile
.
F_bn0
,
bk0
=
self
.
F_tile
.
F_bk0
,
bn1
=
self
.
F_tile
.
F_bn1
,
bk1
=
self
.
F_tile
.
F_bk1
,
bk0blen
=
self
.
F_tile
.
F_bk0blen
,
vlayout
=
self
.
F_pipeline
.
F_vlayout
,
mask
=
self
.
F_pipeline
.
F_mask
,
bias
=
self
.
F_pipeline
.
F_bias
,
lse
=
self
.
F_pipeline
.
F_lse
,
dropout
=
self
.
F_pipeline
.
F_dropout
,
squant
=
self
.
F_pipeline
.
F_squant
,
spad
=
self
.
F_pipeline
.
F_spad
,
skpad
=
self
.
F_pipeline
.
F_skpad
,
dpad
=
self
.
F_pipeline
.
F_dpad
,
dvpad
=
self
.
F_pipeline
.
F_dvpad
)
# TODO: design a more practical way to do it
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
# this is current supported tile size per hdim
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
def
get_fmha_fwd_tile_dict_from_dtype
(
dtype
:
str
)
->
Optional
[
dict
]:
...
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant
=
't'
if
dtype
==
'fp8'
else
'f'
squant
=
't'
if
dtype
==
'fp8'
else
'f'
pipelines
=
[]
pipelines
=
[]
if
dtype
in
[
'fp16'
,
'bf16'
]:
if
dtype
in
[
'fp16'
,
'bf16'
]:
# splitkv kernel donot support dropout
for
mask
,
bias
,
lse
,
pagedkv
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"t"
,
"f"
]):
for
mask
,
bias
,
lse
,
dropout
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
(),
[
"t"
,
"f"
],
[
"f"
]):
# TODO: use async pipeline when compiler is more stable
if
hdim
==
256
:
if
hdim
==
256
or
hdim
in
[
32
,
64
,
128
]
:
# if True:
# if True:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
else
:
else
:
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr_async'
,
'col'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
if
receipt
==
1
:
if
receipt
==
1
:
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'row'
,
't'
,
't'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
dropout
,
squant
,
mask
))
# TODO: cover arbitraty hdim
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
't'
,
'f'
,
't'
,
't'
,
bias
,
lse
,
squant
,
pagedkv
,
mask
))
# TODO: cover arbitraty hdim
elif
dtype
in
[
'fp8'
,
'bf8'
]:
elif
dtype
in
[
'fp8'
,
'bf8'
]:
# no need lse/
dropout
kernels
# no need lse/
paged-kv
kernels
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
for
mask
,
bias
in
itertools
.
product
(
get_mask_map
(
mask_impl
).
keys
(),
BIAS_MAP
.
keys
()):
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
'f'
,
squant
,
mask
))
pipelines
.
append
(
Pipeline
(
'qr'
,
'col'
,
'f'
,
'f'
,
'f'
,
'f'
,
bias
,
'f'
,
squant
,
'f'
,
mask
))
else
:
else
:
assert
False
assert
False
return
pipelines
return
pipelines
...
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
...
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
if
pipeline
.
F_spad
!=
't'
or
pipeline
.
F_skpad
!=
't'
:
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
continue
if
pipeline
.
F_pagedkv
==
't'
:
# we only use batch mode kernels to handle (paged-) kvcache problems
continue
k
=
Kernel
(
F_idx
=
0
,
k
=
Kernel
(
F_idx
=
0
,
F_hdim
=
hdim
,
F_hdim
=
hdim
,
F_dtype
=
dtype
,
F_dtype
=
dtype
,
...
...
example/ck_tile/01_fmha/fmha_fwd.cpp
View file @
4885c38a
This diff is collapsed.
Click to expand it.
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
4885c38a
...
@@ -5,10 +5,13 @@
...
@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
#include <type_traits>
template
<
typename
DataType
>
template
<
typename
DataType
>
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
const
void
*
v_ptr
;
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
rand_val_ptr
;
void
*
rand_val_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
// only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
float
scale_s
;
float
scale_p
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
};
struct
fmha_fwd_splitkv_args
{
const
void
*
q_ptr
;
const
void
*
k_ptr
;
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
lse_acc_ptr
;
void
*
lse_acc_ptr
;
void
*
o_acc_ptr
;
void
*
o_acc_ptr
;
void
*
lse_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
void
*
o_ptr
;
void
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
const
void
*
cache_batch_idx
;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
const
void
*
seqlen_k_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
batch
;
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
ck_tile
::
index_t
nhead_k
;
ck_tile
::
index_t
num_splits
;
ck_tile
::
index_t
num_splits
;
float
scale_s
;
float
scale_s
;
float
scale_p
;
float
scale_p
;
float
scale_o
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o_acc
;
ck_tile
::
index_t
stride_o_acc
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
};
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
struct
fmha_fwd_appendkv_args
{
void
*
q_ptr
;
void
*
k_ptr
;
const
void
*
knew_ptr
;
void
*
v_ptr
;
const
void
*
vnew_ptr
;
const
void
*
seqlen_k_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_knew
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
const
void
*
rotary_cos_ptr
;
// only used if 'rotary_dim' > 0
const
void
*
rotary_sin_ptr
;
// only used if 'rotary_dim' > 0
ck_tile
::
index_t
rotary_dim
;
bool
has_mask
;
void
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
const
void
*
cache_batch_idx
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_knew
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_vnew
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_knew
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_vnew
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_knew
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_vnew
;
};
};
template
<
typename
FmhaKernel
>
template
<
typename
FmhaKernel
>
...
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
}
template
<
typename
Kernel
>
template
<
typename
Kernel
>
auto
fmha_fwd_splitkv_create_kargs_and_grids
(
fmha_fwd_args
args
)
auto
fmha_fwd_splitkv_create_kargs_and_grids
(
fmha_fwd_
splitkv_
args
args
)
{
{
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
auto
kargs
=
[
&
]
{
auto
kargs
=
[
&
]
{
...
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
lse_acc_ptr
,
args
.
lse_acc_ptr
,
args
.
o_acc_ptr
,
args
.
o_acc_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqstart_q_ptr
,
args
.
seqstart_q_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqstart_k_ptr
,
args
.
seqlen_k_ptr
,
args
.
seqlen_k_ptr
,
...
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_o_acc
,
args
.
stride_o_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
);
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
}
}
else
else
{
// create batch mode kernel arguments
{
// create batch mode kernel arguments
...
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args
.
k_ptr
,
args
.
k_ptr
,
args
.
v_ptr
,
args
.
v_ptr
,
args
.
bias_ptr
,
args
.
bias_ptr
,
args
.
rand_val_ptr
,
args
.
lse_acc_ptr
,
args
.
lse_acc_ptr
,
args
.
o_acc_ptr
,
args
.
o_acc_ptr
,
args
.
batch
,
args
.
batch
,
args
.
max_seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_q
,
args
.
seqlen_k
,
args
.
seqlen_k
,
args
.
seqlen_k_ptr
,
args
.
hdim_q
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
num_splits
,
args
.
num_splits
,
args
.
block_table_ptr
,
args
.
batch_stride_block_table
,
args
.
page_block_size
,
args
.
cache_batch_idx
,
args
.
scale_s
,
args
.
scale_s
,
args
.
scale_p
,
args
.
scale_p
,
args
.
stride_q
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_k
,
args
.
stride_v
,
args
.
stride_v
,
args
.
stride_bias
,
args
.
stride_bias
,
args
.
stride_randval
,
args
.
stride_o_acc
,
args
.
stride_o_acc
,
args
.
nhead_stride_q
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_k
,
args
.
nhead_stride_v
,
args
.
nhead_stride_v
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_bias
,
args
.
nhead_stride_randval
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_lse_acc
,
args
.
nhead_stride_o_acc
,
args
.
nhead_stride_o_acc
,
args
.
batch_stride_q
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_k
,
args
.
batch_stride_v
,
args
.
batch_stride_v
,
args
.
batch_stride_bias
,
args
.
batch_stride_bias
,
args
.
batch_stride_randval
,
args
.
batch_stride_lse_acc
,
args
.
batch_stride_lse_acc
,
args
.
batch_stride_o_acc
,
args
.
batch_stride_o_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_lse_acc
,
args
.
split_stride_o_acc
,
args
.
split_stride_o_acc
,
args
.
window_size_left
,
args
.
window_size_left
,
args
.
window_size_right
,
args
.
window_size_right
,
args
.
mask_type
,
args
.
mask_type
);
args
.
p_drop
,
args
.
s_randval
,
args
.
drop_seed_offset
);
}
}
}();
}();
...
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}
}
template
<
typename
Kernel
>
template
<
typename
Kernel
>
auto
fmha_fwd_splitkv_combine_create_kargs_and_grids
(
fmha_fwd_args
args
)
auto
fmha_fwd_splitkv_combine_create_kargs_and_grids
(
fmha_fwd_
splitkv_
args
args
)
{
{
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
auto
kargs
=
[
&
]
{
auto
kargs
=
[
&
]
{
...
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
...
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
return
ck_tile
::
make_tuple
(
kargs
,
grids
);
return
ck_tile
::
make_tuple
(
kargs
,
grids
);
}
}
template
<
typename
Kernel
>
auto
fmha_fwd_appendkv_create_kargs_and_grids
(
fmha_fwd_appendkv_args
args
)
{
assert
(
args
.
nhead_q
%
args
.
nhead_k
==
0
);
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
q_ptr
,
args
.
k_ptr
,
args
.
knew_ptr
,
args
.
v_ptr
,
args
.
vnew_ptr
,
args
.
seqlen_q
,
args
.
seqlen_k_ptr
,
args
.
seqlen_knew
,
args
.
hdim_q
,
args
.
hdim_v
,
args
.
nhead_q
,
args
.
nhead_q
/
args
.
nhead_k
,
args
.
rotary_cos_ptr
,
args
.
rotary_sin_ptr
,
args
.
rotary_dim
,
args
.
has_mask
,
args
.
block_table_ptr
,
args
.
batch_stride_block_table
,
args
.
page_block_size
,
args
.
cache_batch_idx
,
args
.
stride_q
,
args
.
stride_k
,
args
.
stride_knew
,
args
.
stride_v
,
args
.
stride_vnew
,
args
.
nhead_stride_q
,
args
.
nhead_stride_k
,
args
.
nhead_stride_knew
,
args
.
nhead_stride_v
,
args
.
nhead_stride_vnew
,
args
.
batch_stride_q
,
args
.
batch_stride_k
,
args
.
batch_stride_knew
,
args
.
batch_stride_v
,
args
.
batch_stride_vnew
);
dim3
grids
=
Kernel
::
GridSize
(
args
.
batch
,
args
.
nhead_q
,
args
.
seqlen_q
,
args
.
seqlen_knew
);
return
ck_tile
::
make_tuple
(
kargs
,
grids
);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
ck_tile
::
index_t
HDim_
,
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
typename
DataType_
,
...
@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
...
@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
template
<
typename
Traits_
>
template
<
typename
Traits_
>
float
fmha_fwd_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
float
fmha_fwd_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
bool
kIsGroupMode_
,
ck_tile
::
index_t
kM0_
,
ck_tile
::
index_t
kN0_
,
ck_tile
::
index_t
kK0_
,
ck_tile
::
index_t
kN1_
,
ck_tile
::
index_t
kK1_
,
ck_tile
::
index_t
kK0BlockLength_
,
bool
kIsVLayoutRowMajor_
,
ck_tile
::
BlockFmhaPipelineEnum
FmhaPipelineEnum_
,
typename
FmhaMask_
,
ck_tile
::
BlockAttentionBiasEnum
BiasEnum_
,
bool
kStoreLse_
,
bool
kDoFp8StaticQuant_
,
bool
kIsPagedKV_
,
bool
kPadS_
,
bool
kPadSK_
,
bool
kPadD_
,
bool
kPadDv_
>
struct
fmha_fwd_splitkv_traits_
{
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN0
=
kN0_
;
static
constexpr
ck_tile
::
index_t
kK0
=
kK0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
static
constexpr
ck_tile
::
index_t
kK1
=
kK1_
;
static
constexpr
ck_tile
::
index_t
kK0BlockLength
=
kK0BlockLength_
;
static
constexpr
bool
kIsVLayoutRowMajor
=
kIsVLayoutRowMajor_
;
static
constexpr
auto
FmhaPipelineEnum
=
FmhaPipelineEnum_
;
using
FmhaMask
=
ck_tile
::
remove_cvref_t
<
FmhaMask_
>
;
static
constexpr
auto
BiasEnum
=
BiasEnum_
;
static
constexpr
bool
kStoreLse
=
kStoreLse_
;
static
constexpr
bool
kDoFp8StaticQuant
=
kDoFp8StaticQuant_
;
static
constexpr
bool
kPadS
=
kPadS_
;
static
constexpr
bool
kPadSK
=
kPadSK_
;
static
constexpr
bool
kPadD
=
kPadD_
;
static
constexpr
bool
kPadDv
=
kPadDv_
;
static
constexpr
bool
kIsPagedKV
=
kIsPagedKV_
;
};
template
<
typename
Traits_
>
template
<
typename
Traits_
>
void
fmha_fwd_splitkv_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
void
fmha_fwd_splitkv_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_
splitkv_
args
);
template
<
typename
Traits_
>
template
<
typename
Traits_
>
std
::
string
fmha_fwd_splitkv_get_name_
();
std
::
string
fmha_fwd_splitkv_get_name_
();
...
@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
...
@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
};
};
template
<
typename
Traits_
>
template
<
typename
Traits_
>
void
fmha_fwd_splitkv_combine_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
void
fmha_fwd_splitkv_combine_oneshot_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_
splitkv_
args
);
template
<
typename
Traits_
>
template
<
typename
Traits_
>
std
::
string
fmha_fwd_splitkv_combine_get_name_
();
std
::
string
fmha_fwd_splitkv_combine_get_name_
();
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template
<
ck_tile
::
index_t
HDim_
,
typename
DataType_
,
ck_tile
::
index_t
kTileSizeS_
,
ck_tile
::
index_t
kTileSizeSk_
,
ck_tile
::
index_t
kTileSizeD_
,
ck_tile
::
index_t
kTileSizeDv_
,
bool
kIsVLayoutRowMajor_
,
bool
kPadS_
,
bool
kPadSk_
,
bool
kPadD_
,
bool
kPadDv_
,
ck_tile
::
RotaryEmbeddingEnum
RotaryEnum_
,
bool
kIsPagedKV_
>
struct
fmha_fwd_appendkv_traits_
{
static
constexpr
ck_tile
::
index_t
HDim
=
HDim_
;
using
DataType
=
ck_tile
::
remove_cvref_t
<
DataType_
>
;
static
constexpr
ck_tile
::
index_t
kTileSizeS
=
kTileSizeS_
;
static
constexpr
ck_tile
::
index_t
kTileSizeSk
=
kTileSizeSk_
;
static
constexpr
ck_tile
::
index_t
kTileSizeD
=
kTileSizeD_
;
static
constexpr
ck_tile
::
index_t
kTileSizeDv
=
kTileSizeDv_
;
static
constexpr
bool
kIsVLayoutRowMajor
=
kIsVLayoutRowMajor_
;
static
constexpr
bool
kPadS
=
kPadS_
;
static
constexpr
bool
kPadSk
=
kPadSk_
;
static
constexpr
bool
kPadD
=
kPadD_
;
static
constexpr
bool
kPadDv
=
kPadDv_
;
static
constexpr
auto
RotaryEnum
=
RotaryEnum_
;
static
constexpr
bool
kIsPagedKV
=
kIsPagedKV_
;
};
template
<
typename
Traits_
>
float
fmha_fwd_appendkv_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_appendkv_args
);
// This is the public API, will be generated by script
// This is the public API, will be generated by script
struct
fmha_fwd_traits
struct
fmha_fwd_traits
{
{
...
@@ -508,4 +743,32 @@ struct fmha_fwd_traits
...
@@ -508,4 +743,32 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api
// TODO: padding check is inside this api
};
};
float
fmha_fwd
(
fmha_fwd_traits
,
fmha_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_fwd
(
fmha_fwd_traits
,
fmha_fwd_args
,
const
ck_tile
::
stream_config
&
);
float
fmha_fwd_splitkv
(
fmha_fwd_traits
,
fmha_fwd_args
,
const
ck_tile
::
stream_config
&
);
struct
fmha_fwd_splitkv_traits
{
int
hdim_q
;
int
hdim_v
;
std
::
string
data_type
;
bool
is_group_mode
;
bool
is_v_rowmajor
;
mask_enum
mask_type
;
bias_enum
bias_type
;
// 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool
has_lse
;
bool
do_fp8_static_quant
;
// TODO: padding check is inside this api
};
float
fmha_fwd_splitkv
(
fmha_fwd_splitkv_traits
,
fmha_fwd_splitkv_args
,
const
ck_tile
::
stream_config
&
);
struct
fmha_fwd_appendkv_traits
{
int
hdim_q
;
int
hdim_v
;
std
::
string
data_type
;
bool
is_v_rowmajor
;
rope_enum
rope_type
;
};
float
fmha_fwd_appendkv
(
fmha_fwd_appendkv_traits
,
fmha_fwd_appendkv_args
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/01_fmha/generate.py
View file @
4885c38a
...
@@ -5,25 +5,30 @@
...
@@ -5,25 +5,30 @@
import
argparse
import
argparse
from
enum
import
IntEnum
from
enum
import
IntEnum
from
pathlib
import
Path
from
pathlib
import
Path
import
pkgutil
import
sys
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
codegen.ops
from
codegen.cmake_config
import
*
from
codegen.cmake_config
import
*
from
codegen.ops
import
(
fmha_fwd
,
fmha_fwd_splitkv
,
fmha_bwd
)
class
HandlerId
(
IntEnum
):
class
HandlerId
(
IntEnum
):
LIST_BLOBS
=
0
LIST_BLOBS
=
0
WRITE_BLOBS
=
1
WRITE_BLOBS
=
1
handlers
=
{
# inspect all modules under 'codegen.ops' and register API handlers
'fwd'
:
(
fmha_fwd
.
list_blobs
,
fmha_fwd
.
write_blobs
),
ops
=
[]
'fwd_splitkv'
:
(
fmha_fwd_splitkv
.
list_blobs
,
fmha_fwd_splitkv
.
write_blobs
),
for
importer
,
module_name
,
_
in
pkgutil
.
iter_modules
(
codegen
.
ops
.
__path__
):
'bwd'
:
(
fmha_bwd
.
list_blobs
,
fmha_bwd
.
write_blobs
),
full_module_name
=
'%s.%s'
%
(
codegen
.
ops
.
__name__
,
module_name
)
}
if
full_module_name
not
in
sys
.
modules
:
ops
.
append
(
importer
.
find_spec
(
module_name
).
loader
.
load_module
(
module_name
))
unwanted_prefix
=
'fmha_'
handlers
=
dict
(
[(
op
.
__name__
[
len
(
unwanted_prefix
):]
if
op
.
__name__
.
startswith
(
unwanted_prefix
)
else
op
.
__name__
,
(
op
.
list_blobs
,
op
.
write_blobs
))
for
op
in
ops
]
)
assert
0
<
len
(
handlers
)
def
write_blobs
(
output_dir
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
def
write_blobs
(
output_dir
:
Optional
[
str
],
api_list
:
List
[
str
],
kernel_filter
:
Optional
[
str
],
receipt
,
mask_impl
)
->
None
:
if
output_dir
is
None
:
if
output_dir
is
None
:
...
@@ -103,4 +108,4 @@ if __name__ == "__main__":
...
@@ -103,4 +108,4 @@ if __name__ == "__main__":
if
args
.
list_blobs
is
not
None
:
if
args
.
list_blobs
is
not
None
:
list_blobs
(
args
.
list_blobs
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
list_blobs
(
args
.
list_blobs
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
else
:
else
:
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
write_blobs
(
args
.
output_dir
,
api_list
,
args
.
filter
,
int
(
args
.
receipt
),
mask_impl
=
args
.
mask
)
\ No newline at end of file
example/ck_tile/01_fmha/rotary.hpp
0 → 100644
View file @
4885c38a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep sync with RotaryEmbeddingEnum
enum
class
rope_enum
{
none
=
0
,
interleaved
=
1
,
half_rotated
=
2
,
};
template
<
typename
DataType
>
std
::
tuple
<
ck_tile
::
HostTensor
<
DataType
>
,
ck_tile
::
HostTensor
<
DataType
>>
generate_rotary_cos_sin
(
ck_tile
::
index_t
seqlen
,
ck_tile
::
index_t
rotary_dim
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
// return dummy tensors if we won't apply RoPE at all
if
(
rotary_dim
<=
0
)
{
ck_tile
::
HostTensor
<
DataType
>
dummy
({
1
,
1
});
return
std
::
make_tuple
(
dummy
,
dummy
);
}
std
::
mt19937
random_engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_real_distribution
<
float
>
generator
(
0.0
f
,
1.0
f
);
const
ck_tile
::
index_t
num_rows
=
seqlen
*
2
;
const
ck_tile
::
index_t
num_cols
=
rotary_dim
/
2
;
using
std
::
begin
,
std
::
end
;
ck_tile
::
HostTensor
<
float
>
angle
({
num_rows
,
num_cols
});
std
::
generate
(
begin
(
angle
),
end
(
angle
),
[
&
]
{
return
generator
(
random_engine
)
*
2
*
M_PI
;
});
ck_tile
::
HostTensor
<
DataType
>
cos
({
num_rows
,
num_cols
});
std
::
transform
(
begin
(
angle
),
end
(
angle
),
begin
(
cos
),
[](
float
origin_value
)
{
return
ck_tile
::
type_convert
<
DataType
>
(
std
::
cos
(
origin_value
));
});
ck_tile
::
HostTensor
<
DataType
>
sin
({
num_rows
,
num_cols
});
std
::
transform
(
begin
(
angle
),
end
(
angle
),
begin
(
sin
),
[](
float
origin_value
)
{
return
ck_tile
::
type_convert
<
DataType
>
(
std
::
sin
(
origin_value
));
});
return
std
::
make_tuple
(
cos
,
sin
);
}
template
<
typename
DataType
>
std
::
tuple
<
ck_tile
::
HostTensor
<
DataType
>
,
ck_tile
::
HostTensor
<
DataType
>>
slice_rotary_cos_sin
(
const
ck_tile
::
HostTensor
<
DataType
>&
cos
,
const
ck_tile
::
HostTensor
<
DataType
>&
sin
,
ck_tile
::
index_t
seqlen_offset
,
ck_tile
::
index_t
seqlen
)
{
assert
(
cos
.
get_num_of_dimension
()
==
2
&&
sin
.
get_num_of_dimension
()
==
2
);
assert
(
cos
.
get_length
(
0
)
==
sin
.
get_length
(
0
)
&&
cos
.
get_length
(
1
)
==
sin
.
get_length
(
1
));
assert
(
static_cast
<
std
::
size_t
>
(
seqlen_offset
+
seqlen
)
<=
cos
.
get_length
(
0
));
const
ck_tile
::
index_t
num_rows
=
seqlen
;
const
ck_tile
::
index_t
num_cols
=
cos
.
get_length
(
1
);
ck_tile
::
HostTensor
<
DataType
>
cos_pt
({
num_rows
,
num_cols
});
cos_pt
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
cos
(
i
[
0
]
+
seqlen_offset
,
i
[
1
]);
});
ck_tile
::
HostTensor
<
DataType
>
sin_pt
({
num_rows
,
num_cols
});
sin_pt
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
self
(
i
)
=
sin
(
i
[
0
]
+
seqlen_offset
,
i
[
1
]);
});
return
std
::
make_tuple
(
cos_pt
,
sin_pt
);
}
example/ck_tile/01_fmha/script/benchmark_bwd.sh
View file @
4885c38a
#!/bin/sh
#!/bin/sh
# TODO: run this script from CK root
# TODO: run this script from CK root or build directory
BUILD
=
build
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
EXE
=
$BUILD
/bin/tile_example_fmha_bwd
VALID
=
0
VALID
=
0
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"fp16"
"bf16"
;
do
...
...
example/ck_tile/01_fmha/script/benchmark_fwd.sh
View file @
4885c38a
#!/bin/sh
#!/bin/sh
# TODO: run this script from CK root
# TODO: run this script from CK root or build directory
BUILD
=
build
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
EXE
=
$BUILD
/bin/tile_example_fmha_fwd
VALID
=
0
VALID
=
0
for
prec
in
"fp16"
"bf16"
;
do
for
prec
in
"fp16"
"bf16"
;
do
...
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
View file @
4885c38a
#!/bin/sh
#!/bin/sh
# TODO: run this script from CK root
# TODO: run this script from CK root or build directory
BUILD
=
build
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
EXE
=
$BUILD
/bin/tile_example_fmha_bwd
KNAME
=
1
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_WARMUP
=
0
...
...
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
View file @
4885c38a
#!/bin/sh
#!/bin/bash
# TODO: run this script from CK root
# TODO: run this script from CK root or build directory
BUILD
=
build
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
EXE
=
$BUILD
/bin/tile_example_fmha_fwd
KNAME
=
1
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_WARMUP
=
0
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS
=
'-v=1 -warmup=0 -repeat=1'
COMMON_ARGS
=
'-v=1 -warmup=0 -repeat=1'
# mode=0
# mode=0
# export HIP_VISIBLE_DEVICES=4
# export HIP_VISIBLE_DEVICES=4
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
mode
in
1 0
;
do
for
perm
in
0 1
;
do
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-kname
=
$KNAME
$COMMON_ARGS
done
TEST_SPLITKV
=
0
done
TEST_APPENDKV
=
0
done
# options:
done
# -s: run splitkv tests
done
# -a: run appendkv tests
done
while
getopts
":sa"
opt
;
do
done
case
"
${
opt
}
"
in
s
)
TEST_SPLITKV
=
1
;;
a
)
TEST_APPENDKV
=
1
;;
*
)
;;
esac
done
done
run_fp16_bf16_tests
()
{
local
NUM_SPLITS
=(
1
)
local
PAGE_BLOCK_SIZE
=(
0
)
local
CACHE_BATCH_IDX
=(
0
)
for
perm
in
0 1
;
do
if
[
$TEST_SPLITKV
-eq
1
]
;
then
for
bias
in
"n"
"e"
"a"
;
do
NUM_SPLITS+
=(
2 3
)
for
b
in
1 2
;
do
PAGE_BLOCK_SIZE+
=(
128
)
for
hdim
in
64 128 256
;
do
CACHE_BATCH_IDX+
=(
1
)
$EXE
-prec
=
fp8
-init
=
3
-b
=
$b
-h
=
1
-d
=
128
-s
=
128
-bias
=
$bias
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-squant
=
1
-kname
=
$KNAME
$COMMON_ARGS
fi
done
done
for
prec
in
"fp16"
"bf16"
;
do
done
for
mode
in
1 0
;
do
done
for
perm
in
0 1
;
do
set
+x
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
for
num_splits
in
"
${
NUM_SPLITS
[@]
}
"
;
do
for
page_block_size
in
"
${
PAGE_BLOCK_SIZE
[@]
}
"
;
do
for
cache_batch_idx
in
"
${
CACHE_BATCH_IDX
[@]
}
"
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
;
done
;
done
done
;
}
run_fp8_tests
()
{
for
perm
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
b
in
1 2
;
do
for
hdim
in
64 128 256
;
do
$EXE
-prec
=
fp8
-init
=
3
-b
=
$b
-h
=
1
-d
=
128
-s
=
128
-bias
=
$bias
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-squant
=
1
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
}
run_fp16_appendkv_tests
()
{
for
s
in
$(
seq
63 1 65
)
;
do
for
s_k
in
65 129
;
do
for
s_knew
in
0 64
$s_k
;
do
for
hdim
in
32 64 128 256
;
do
for
ri
in
0 1
;
do
for
rdim
in
0 16 32
$hdim
;
do
for
page_block_size
in
0 128
;
do
for
cache_batch_idx
in
0 1
;
do
$EXE
-prec
=
fp16
-b
=
3
-h
=
3
-d
=
$hdim
-s
=
$s
-s_k
=
$s_k
-s_knew
=
$s_knew
-rotary_dim
=
$rdim
-rotary_interleaved
=
$ri
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-iperm
=
1
-operm
=
1
-kname
=
1
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
}
set
-x
run_fp16_bf16_tests
run_fp8_tests
if
[
$TEST_APPENDKV
-eq
1
]
;
then
run_fp16_appendkv_tests
fi
set
+x
\ No newline at end of file
example/ck_tile/01_fmha/utils.hpp
View file @
4885c38a
...
@@ -3,15 +3,17 @@
...
@@ -3,15 +3,17 @@
#pragma once
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdint>
#include <cstdlib>
#include <cstdlib>
#include <functional>
#include <optional>
#include <optional>
#include <ostream>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <tuple>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/span.hpp"
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std
::
vector
<
int32_t
>
generate_seqlens
(
mode_enum
mode
,
std
::
vector
<
int32_t
>
generate_seqlens
(
mode_enum
mode
,
unsigned
count
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
// if not negative, clamp min
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
{
assert
(
0
<
count
);
assert
(
0
<
count
);
std
::
vector
<
int32_t
>
seqlens
(
seqlen_min
=
(
0
<
seqlen_min
?
seqlen_min
:
1
);
count
,
seqlen_max
>
0
?
(
seqlen_avg
<
seqlen_max
?
seqlen_avg
:
seqlen_max
)
:
seqlen_avg
);
seqlen_max
=
(
0
<
seqlen_max
?
seqlen_max
:
std
::
numeric_limits
<
int32_t
>::
max
());
assert
(
seqlen_min
<=
seqlen_max
);
std
::
vector
<
int32_t
>
seqlens
(
count
,
std
::
clamp
(
seqlen_avg
,
seqlen_min
,
seqlen_max
));
if
(
mode
==
mode_enum
::
group
&&
1
<
count
)
if
(
mode
==
mode_enum
::
group
&&
1
<
count
)
{
{
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
{
{
const
size_type
to_decrease
=
next_idx
();
const
size_type
to_decrease
=
next_idx
();
// make sure each elements of seqlens is
always greater than 0
// make sure each elements of seqlens is
in range [seqlen_min, seqlen_max]
if
(
seqlens
[
to_decrease
]
==
1
)
if
(
seqlens
[
to_decrease
]
==
seqlen_min
)
{
{
continue
;
continue
;
}
}
const
size_type
to_increase
=
(
to_decrease
+
next_step
())
%
count
;
const
size_type
to_increase
=
(
to_decrease
+
next_step
())
%
count
;
if
(
seqlen_max
>
0
&&
seqlens
[
to_increase
]
>=
seqlen_max
)
if
(
seqlens
[
to_increase
]
>=
seqlen_max
)
{
{
continue
;
continue
;
}
}
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std
::
vector
<
int32_t
>
generate_seqstarts
(
mode_enum
mode
,
std
::
vector
<
int32_t
>
generate_seqstarts
(
mode_enum
mode
,
unsigned
count
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
int32_t
seqlen_max
=
-
1
,
int32_t
seqlen_max
=
-
1
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
{
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_max
,
seed
));
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_min
,
seqlen_max
,
seed
));
}
// return random integer generated uniformly in range [low, high]
template
<
typename
Int
=
int
>
auto
randint
(
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
->
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>
,
Int
>
{
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_int_distribution
<
Int
>
dist
(
low
,
high
);
return
dist
(
engine
);
}
// return random integers generated uniformly in range [low, high]
template
<
typename
Int
,
typename
ForwardIterator
>
auto
randints
(
ForwardIterator
first
,
ForwardIterator
last
,
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
->
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>>
{
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_int_distribution
<
Int
>
dist
(
low
,
high
);
std
::
generate
(
first
,
last
,
[
&
]
{
return
dist
(
engine
);
});
}
}
/*
/*
...
@@ -112,16 +144,45 @@ decode_seqlen(mode_enum mode,
...
@@ -112,16 +144,45 @@ decode_seqlen(mode_enum mode,
std
::
string
q_val
,
std
::
string
q_val
,
std
::
string
k_val
,
std
::
string
k_val
,
std
::
string
k_pad_val
,
std
::
string
k_pad_val
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
ck_tile
::
index_t
seqlen_k_min
=
0
,
bool
use_kvcache
=
false
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if
(
mode
==
mode_enum
::
batch
)
if
(
mode
==
mode_enum
::
batch
)
{
{
ck_tile
::
index_t
q
=
_S2I_
(
q_val
);
ck_tile
::
index_t
q
=
_S2I_
(
q_val
);
ck_tile
::
index_t
k
=
_S2I_
(
k_val
);
ck_tile
::
index_t
k
=
_S2I_
(
k_val
);
auto
s_q
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
q
);
auto
s_k
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
k
<
0
?
q
:
k
);
auto
s_q
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
q
);
auto
s_k
=
[
&
]
{
const
ck_tile
::
index_t
seqlen_k_max
=
(
k
<
0
?
q
:
k
);
std
::
vector
<
ck_tile
::
index_t
>
seqlen_ks
(
batch
,
seqlen_k_max
);
if
(
1
<
batch
&&
use_kvcache
)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints
(
std
::
next
(
seqlen_ks
.
begin
()),
seqlen_ks
.
end
(),
seqlen_k_min
,
seqlen_k_max
,
seed
);
return
seqlen_ks
;
}
return
seqlen_ks
;
}();
auto
s_kpad
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
-
1
);
// TODO: batch not support k_padding
auto
s_kpad
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
-
1
);
// TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
return
std
::
make_tuple
(
s_q
,
s_k
,
s_kpad
);
return
std
::
make_tuple
(
s_q
,
s_k
,
s_kpad
);
}
}
else
else
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
s_q
.
push_back
(
q
);
s_q
.
push_back
(
q
);
s_k
.
push_back
(
k
<
0
?
q
:
k
);
s_k
.
push_back
(
k
<
0
?
q
:
k
);
s_kpad
.
push_back
(
kp
);
s_kpad
.
push_back
(
kp
);
// s_k should be greater than or equal to seqlen_k_min
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
idx
++
;
idx
++
;
if
(
found_q
==
std
::
string
::
npos
||
idx
>=
batch
)
if
(
found_q
==
std
::
string
::
npos
||
idx
>=
batch
)
{
{
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
}
}
if
(
idx
<
batch
)
if
(
idx
<
batch
)
{
{
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
s_kpad
.
back
(),
seed
);
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
1
,
s_kpad
.
back
(),
seed
);
auto
rem_k
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
s_kpad
.
back
(),
seed
);
auto
rem_k
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
seqlen_k_min
,
s_kpad
.
back
(),
seed
);
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
r
=
std
::
atoi
(
v
);
r
=
std
::
atoi
(
v
);
return
r
;
return
r
;
}
}
template
<
typename
RandomAccessIterator
,
typename
Int
>
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>>
iota_shuffle
(
RandomAccessIterator
first
,
RandomAccessIterator
last
,
Int
value
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
std
::
iota
(
first
,
last
,
value
);
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
shuffle
(
first
,
last
,
engine
);
}
include/ck/ck.hpp
View file @
4885c38a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -153,8 +153,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly
// LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set
stochastic rounding
as default for f8 conversions
// set
rounding to nearest even
as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION
1
#define CK_USE_SR_F8_CONVERSION
0
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
View file @
4885c38a
...
@@ -67,8 +67,8 @@ struct BlockwiseGemmXdlops_pipeline_base
...
@@ -67,8 +67,8 @@ struct BlockwiseGemmXdlops_pipeline_base
KPerBlock
,
KPerBlock
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
A
_K1
,
A
BlockTransferSrcScalarPerVector
,
B
_K1
,
B
BlockTransferSrcScalarPerVector
,
A_K1
,
A_K1
,
B_K1
,
B_K1
,
MRepeat
,
MRepeat
,
...
...
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp
View file @
4885c38a
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
#pragma once
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
...
@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
UnaryOp2
unary_op2_
{};
UnaryOp2
unary_op2_
{};
};
};
using
ScaleScalePass
=
UnaryCombinedOp
<
Scale
,
Scale
,
PassThrough
>
;
using
ScaleScaleRelu
=
UnaryCombinedOp
<
Scale
,
Scale
,
Relu
>
;
}
// namespace element_wise
}
// namespace element_wise
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
4885c38a
...
@@ -752,11 +752,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -752,11 +752,18 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
{
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
if
constexpr
(
BBlockLdsExtraN
)
// if constexpr(BBlockLdsExtraN)
// {
// return make_naive_tensor_descriptor(
// make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
// make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
// }
// else
if
constexpr
(
BBlockLdsExtraN
&&
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
BK1Number
,
Number
<
K
PerBlock
+
BBlockLdsExtraN
>
{},
I1
));
make_tuple
(
BK1Number
*
Number
<
N
PerBlock
>
{},
I1
,
Number
<
NPerBlock
>
{}
));
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
{
...
@@ -1318,9 +1325,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -1318,9 +1325,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
,
2
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
...
...
include/ck_tile/core/config.hpp
View file @
4885c38a
...
@@ -46,6 +46,7 @@
...
@@ -46,6 +46,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...
@@ -156,6 +157,14 @@
...
@@ -156,6 +157,14 @@
#endif
#endif
#endif
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#define CK_TILE_DEBUG_LOG 0
#endif
#endif
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
4885c38a
...
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
...
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
standard
=
0
,
// rtn
standard
=
0
,
// rtn
truncate_with_nan
,
truncate_with_nan
,
truncate
,
truncate
,
standard_asm
,
};
};
template
<
bf16_rounding_mode
rounding
=
template
<
bf16_rounding_mode
rounding
=
...
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
...
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
return
uint16_t
(
u
.
int32
>>
16
);
return
uint16_t
(
u
.
int32
>>
16
);
}
}
CK_TILE_HOST
constexpr
uint16_t
float_to_bf16_rtn_asm
(
float
f
)
{
return
float_to_bf16_rtn_raw
(
f
);
}
CK_TILE_DEVICE
uint16_t
float_to_bf16_rtn_asm
(
float
f
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
f
};
static
constexpr
uint32_t
FP32_NAN
=
0x7fff0000
;
static
constexpr
uint32_t
ROUND_BIAS_FOR_BF16
=
0x7fff
;
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
check_nan
;
uint32_t
tmp
;
asm
volatile
(
"
\n
\
v_cmp_u_f32 %0, %2, %2
\n
\
v_bfe_u32 %1, %2, 16, 1
\n
\
v_add3_u32 %1, %2, %1, %3
\n
\
v_cndmask_b32 %2, %1, %4, %0
\n
\
v_lshrrev_b32 %2, 16, %2
\n
\
"
:
"=s"
(
check_nan
),
"+v"
(
tmp
),
"+v"
(
u
.
fp32
)
:
"v"
(
ROUND_BIAS_FOR_BF16
),
"v"
(
FP32_NAN
));
return
uint16_t
(
u
.
int32
);
}
// Truncate instead of rounding, preserving SNaN
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
constexpr
uint16_t
float_to_bf16_truc_nan_raw
(
float
f
)
constexpr
uint16_t
float_to_bf16_truc_nan_raw
(
float
f
)
...
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
...
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
{
{
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard
)
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard
)
return
float_to_bf16_rtn_raw
(
f
);
return
float_to_bf16_rtn_raw
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard_asm
)
return
float_to_bf16_rtn_asm
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
truncate_with_nan
)
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
truncate_with_nan
)
return
float_to_bf16_truc_nan_raw
(
f
);
return
float_to_bf16_truc_nan_raw
(
f
);
else
else
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment