gaoqiong / composable_kernel_ROCM

Commit e536d321 · authored Sep 04, 2024 by illsilin

    merge from public repo

Parents: 829e0eb3, 52410b49
Changes: 76

Showing 20 changed files with 2165 additions and 400 deletions.
CMakeLists.txt  (+13 -1)
Jenkinsfile  (+35 -9)
example/ck_tile/01_fmha/CMakeLists.txt  (+37 -3)
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py  (+13 -1)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py  (+355 -0)
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py  (+127 -61)
example/ck_tile/01_fmha/fmha_fwd.cpp  (+604 -164)
example/ck_tile/01_fmha/fmha_fwd.hpp  (+293 -30)
example/ck_tile/01_fmha/generate.py  (+16 -11)
example/ck_tile/01_fmha/rotary.hpp  (+84 -0)
example/ck_tile/01_fmha/script/benchmark_bwd.sh  (+2 -3)
example/ck_tile/01_fmha/script/benchmark_fwd.sh  (+2 -3)
example/ck_tile/01_fmha/script/smoke_test_bwd.sh  (+2 -3)
example/ck_tile/01_fmha/script/smoke_test_fwd.sh  (+94 -41)
example/ck_tile/01_fmha/utils.hpp  (+97 -13)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp  (+6 -6)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp  (+5 -5)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp  (+373 -39)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp  (+2 -2)
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp  (+5 -5)
CMakeLists.txt

@@ -26,7 +26,19 @@ set(version 1.1.0)
 project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
 include(CTest)

-find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+if(NOT CK_USE_ALTERNATIVE_PYTHON)
+    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+else()
+    message("Using alternative python version")
+    set(EXTRA_PYTHON_PATH)
+    string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
+    message("alternative python path is: ${EXTRA_PYTHON_PATH}")
+    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+    add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
+    set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
+    set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
+    set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}")
+endif()

 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
Jenkinsfile

@@ -262,10 +262,19 @@ def cmake_build(Map conf=[:]){
     // reduce parallelism when compiling, clang uses too much memory
     def nt = nthreads()
     def cmd
+    def setup_cmd
+    def build_cmd
     def execute_cmd = conf.get("execute_cmd", "")
     if (!setup_args.contains("NO_CK_BUILD")){
-        def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
-        def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+            echo "running ninja build trace"
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
+        }
+        else{
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        }
         cmd = conf.get("cmd", """
             ${setup_cmd}
             ${build_cmd}
@@ -281,7 +290,19 @@ def cmake_build(Map conf=[:]){
     echo cmd
     dir("build"){
+        //build CK
         sh cmd
+        //run tests
+        if (!setup_args.contains("NO_CK_BUILD")){
+            if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+                sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
+                archiveArtifacts "ck_build_trace.json"
+                sh "ninja test"
+            }
+            else{
+                sh "make check"
+            }
+        }
     }
     // Only archive from master or develop
@@ -543,7 +564,7 @@ def Build_CK(Map conf=[:]){
     cmake_build(conf)
     dir("build"){
         //run tests and examples
-        sh 'make -j check'
+        // sh 'make -j check'
         if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0){
             //we only need the ckProfiler to run the performance tests, so we pack and stash it
             //do not stash profiler on nodes where we don't need to run performance tests
@@ -755,7 +776,10 @@ pipeline {
         name: "BUILD_GFX12",
         defaultValue: false,
         description: "Build CK and run tests on gfx12 (default: OFF)")
+    booleanParam(
+        name: "NINJA_BUILD_TRACE",
+        defaultValue: false,
+        description: "Generate a ninja build trace (default: OFF)")
     }
     environment{
         dbuser = "${dbuser}"
@@ -789,6 +813,7 @@ pipeline {
     }
     agent{ label rocmnode("nogpu") }
     environment{
+        setup_args = "NO_CK_BUILD"
         execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
             -o -not -path \'*.git*\' -iname \'*.hpp\' \
             -o -not -path \'*.git*\' -iname \'*.cpp\' \
@@ -805,7 +830,7 @@ pipeline {
         --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
     }
     steps{
-        buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+        buildHipClangJobAndReboot(setup_args: setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
         archiveArtifacts "build/ck_cppcheck.log"
         cleanWs()
     }
@@ -817,6 +842,7 @@ pipeline {
     }
     agent{ label rocmnode("nogpu") }
     environment{
+        setup_args = "NO_CK_BUILD"
         execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
             -o -not -path \'*.git*\' -iname \'*.hpp\' \
             -o -not -path \'*.git*\' -iname \'*.cpp\' \
@@ -828,7 +854,7 @@ pipeline {
         | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
     }
     steps{
-        buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+        buildHipClangJobAndReboot(setup_args: setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
         cleanWs()
     }
 }
@@ -957,10 +983,10 @@ pipeline {
     }
     agent{ label rocmnode("gfx90a") }
     environment{
-        setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
+        setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
         execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
             cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-            -DGPU_TARGETS="gfx1100;gfx90a" \
+            -DGPU_TARGETS="gfx90a" \
             -DCMAKE_CXX_COMPILER="${build_compiler()}" \
             -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
     }
@@ -1064,7 +1090,7 @@ pipeline {
     options { retry(1) }
     agent{ label rocmnode("gfx90a")}
     environment{
-        setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """
+        setup_args = "NO_CK_BUILD"
    }
    steps{
        runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
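When NINJA_BUILD_TRACE is enabled, the job converts .ninja_log into ck_build_trace.json via ninjatracing and archives it. Assuming that artifact is Chrome trace-event JSON (ninjatracing's usual output; this snippet is illustrative, not part of the pipeline), the slowest compile steps can be pulled out like so:

# Illustrative only (not part of this pipeline): summarize ck_build_trace.json,
# assuming ninjatracing's usual Chrome trace-event JSON output.
import json

with open("ck_build_trace.json") as f:
    events = json.load(f)  # a JSON array of trace events

# complete events carry a duration 'dur' in microseconds and the build
# edge's output path in 'name'; print the ten slowest compile steps
slowest = sorted(events, key=lambda e: e.get("dur", 0), reverse=True)[:10]
for e in slowest:
    print(f"{e.get('dur', 0) / 1e6:8.2f}s  {e.get('name', '?')}")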
example/ck_tile/01_fmha/CMakeLists.txt

+# validate user-specified fmha_fwd API list
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
+set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
+  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
+endif()
+
+foreach(api ${FMHA_FWD_ENABLE_APIS})
+  if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
+    message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
+  endif()
+endforeach()
+
+# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
+if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
+endif()
+
+string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
+
 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 execute_process(
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 add_custom_command(
@@ -60,6 +80,20 @@ else()
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)

+# conditionally enable call to the fwd_splitkv API in fmha_fwd example
+if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
+endif()
+
+# conditionally enable call to the fwd_appendkv API in fmha_fwd example
+if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
+endif()
+
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
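A hedged usage sketch (paths are placeholders, not from the commit): the new FMHA_FWD_ENABLE_APIS cache variable is set at configure time to control which forward APIs are generated and linked, e.g. driven from Python:

# Illustrative only (placeholder paths): select which fmha_fwd APIs to generate.
import subprocess

subprocess.run(
    [
        "cmake",
        "-S", ".",        # composable_kernel checkout (placeholder)
        "-B", "build",
        # "all" expands to fwd;fwd_splitkv;fwd_appendkv; "fwd" is re-added if omitted
        "-DFMHA_FWD_ENABLE_APIS=fwd;fwd_appendkv",
        "-DGPU_TARGETS=gfx90a",
    ],
    check=True,
)

Passing the semicolon-separated list as a single argv element avoids any shell splitting of the ";".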
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py

@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
     "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
 }

+ROPE_MAP = {
+    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
+    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
+    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
+}
+
+ROPE_CHECK_MAP = {
+    "no"    : "rope_enum::none",
+    "inter" : "rope_enum::interleaved",
+    "half"  : "rope_enum::half_rotated"
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
@@ -105,4 +117,4 @@ PIPELINE_ENUM_MAP = {
 BOOL_MAP = {
     "t" : "true",
     "f" : "false"
 }
\ No newline at end of file
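For context, a minimal sketch (not part of the commit) of how these string maps are consumed by the generators: a trait key such as 'inter' is looked up in ROPE_MAP to emit the C++ template argument and in ROPE_CHECK_MAP to emit the runtime dispatch condition. The DISPATCH template below is hypothetical, in the style of the real FMHA_FWD_APPENDKV_API_INNER_DISPATCH:

# Minimal sketch (not from this commit): how ROPE_MAP / ROPE_CHECK_MAP feed codegen.
ROPE_MAP = {
    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
}
ROPE_CHECK_MAP = {
    "no"    : "rope_enum::none",
    "inter" : "rope_enum::interleaved",
    "half"  : "rope_enum::half_rotated"
}

# hypothetical dispatch template; {{ }} escape literal braces for str.format
DISPATCH = "if(t.rope_type == {F_rope_check}) {{ run<{F_rope}>(s, a); }}"

for key in ["no", "inter", "half"]:
    # each key expands to one branch of the generated C++ dispatch chain
    print(DISPATCH.format(F_rope=ROPE_MAP[key], F_rope_check=ROPE_CHECK_MAP[key]))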
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py (new file, 0 → 100644)

# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cmake_config import *
from codegen.cpp_symbol_map import *

from codegen.ops.fmha_fwd import (
    FmhaFwdApiTrait,
    DTYPE_BITS,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)

FMHA_FWD_APPENDKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_occupancy}>;

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    {F_bs},
    {F_bsk},
    {F_bd},
    {F_bdv},
    {F_vlayout},
    {F_rope},
    {F_pagedkv},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
    fmha_pipeline_problem_{F_idx}>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
                  fmha_pipeline_{F_idx}>;

using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
                        {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;

#include <iostream>

template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_APPENDKV_API_INNER_DISPATCH="""            {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
                        ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
                using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
                return fmha_fwd_appendkv_<trait_>(s, a);
            }}
"""

@dataclass
class FmhaFwdAppendKVApiTrait:
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    bs      : int  # tile size along q seqlen
    bsk     : int  # tile size along k seqlen
    bd      : int  # tile size along qk gemm unroll
    bdv     : int  # tile size along kv gemm unroll
    vlayout : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str
    rope    : str  # key from ROPE_MAP
    pagedkv : str

    @property
    def name(self) -> str:
        return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-' +\
            f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'

    @property
    def scheck(self) -> str:
        if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
        else :                return f'a.seqlen_q % {self.bs} == 0'

    @property
    def skcheck(self) -> str:
        # we do not check all the values in a.seqlen_k_ptr
        return 'true'

    @property
    def dcheck(self) -> str:
        if self.dpad == 't' : return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else :                return f'a.hdim_q % {self.bd} == 0'

    @property
    def dvcheck(self) -> str:
        if self.dvpad == 't' : return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else :                 return f'a.hdim_v % {self.bdv} == 0'

@dataclass
class FmhaFwdAppendKVPipeline:
    F_vlayout : str  # row/col
    F_spad    : str  # true/false
    F_skpad   : str  #
    F_dpad    : str  #
    F_dvpad   : str  #
    F_rope    : str  # key from ROPE_MAP
    F_pagedkv : str  # t/f

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't' : n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't' : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        if self.F_rope != 'no' : n += f'_{self.F_rope}'
        if self.F_pagedkv == 't' : n += '_pagedkv'
        return n

class FmhaFwdAppendKVApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
                        F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                        F_rope_check=ROPE_CHECK_MAP[trait.rope], F_pagedkv=BOOL_MAP[trait.pagedkv],
                        F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                        F_rope=ROPE_MAP[trait.rope],
                        F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv,
                        F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)

@dataclass
class FmhaFwdAppendKVTileSize:
    F_bs        : int  # tile size along q seqlen
    F_bsk       : int  # tile size along k seqlen
    F_bd        : int  # tile size along qk gemm unroll
    F_bdv       : int  # tile size along kv gemm unroll
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    @property
    def name(self) -> str:
        return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
class FmhaFwdAppendKVKernel:
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaFwdAppendKVTileSize
    F_pipeline : FmhaFwdAppendKVPipeline
    mask_impl  : str

    @property
    def template(self) -> str:
        kernel_body = str()
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_APPENDKV_KERNEL_BODY.format(
                F_idx       = self.F_idx,
                F_hdim      = self.F_hdim,
                F_dtype     = DTYPE_MAP[self.F_dtype],
                F_bs        = self.F_tile.F_bs,
                F_bsk       = self.F_tile.F_bsk,
                F_bd        = self.F_tile.F_bd,
                F_bdv       = self.F_tile.F_bdv,
                F_vlayout   = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad      = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad     = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad      = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad     = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_rope      = ROPE_MAP[self.F_pipeline.F_rope],
                F_pagedkv   = BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_occupancy = self.F_tile.F_occupancy)

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
        return FmhaFwdAppendKVApiTrait(
                hdim=str(self.F_hdim),
                dtype=self.F_dtype,
                bs=self.F_tile.F_bs,
                bsk=self.F_tile.F_bsk,
                bd=self.F_tile.F_bd,
                bdv=self.F_tile.F_bdv,
                vlayout=self.F_pipeline.F_vlayout,
                spad=self.F_pipeline.F_spad,
                skpad=self.F_pipeline.F_skpad,
                dpad=self.F_pipeline.F_dpad,
                dvpad=self.F_pipeline.F_dvpad,
                rope=self.F_pipeline.F_rope,
                pagedkv=self.F_pipeline.F_pagedkv)

# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdAppendKVTileSize(64, 64,  32,  32, -1),
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
        }
    else:
        return None

def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
            # applying rotary embedding, so I just use 't' in inter/half pipelines
            for vlayout in ['row', 'col']:
                for pagedkv in ["t", "f"]:
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
        elif dtype in ['fp8', 'bf8']:
            # rope/paged-kv is not supported
            pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
        if d == None:
            continue
        for hdim_str in d.keys():
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                k = FmhaFwdAppendKVKernel(F_idx=0,
                                          F_hdim=hdim,
                                          F_dtype=dtype,
                                          F_tile=tile,
                                          F_pipeline=pipeline,
                                          mask_impl=mask_impl)
                if kernel_filter != None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)

def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
    (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)

def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    write_fwd_appendkv_api(api_pool, output_dir)

def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    with file_path.open('a') as f:
        _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
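As a usage sketch (illustrative, not from the commit; generate.py normally selects this module via its --api fwd_appendkv flag, and the receipt/mask_impl values below are placeholders):

# Illustrative only: drive the new generator directly.
from pathlib import Path

from codegen.ops import fmha_fwd_appendkv

out = Path("build/autogen")
out.mkdir(parents=True, exist_ok=True)

# writes one .cpp per kernel instance plus the fmha_fwd_appendkv_api.cpp dispatcher
fmha_fwd_appendkv.write_blobs(out, kernel_filter=None, receipt=0, mask_impl="simplified")

# or only record the file names that would be generated (appends to the blob list)
fmha_fwd_appendkv.list_blobs(Path("build/fwd_blob_list.txt"), None, 0, "simplified")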
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
 )

+DTYPE_BITS = {
+    "fp32": 32,
+    "fp16": 16,
+    "bf16": 16,
+    "fp8" : 8,
+    "bf8" : 8
+}
+
 FMHA_FWD_SPLITKV_PIPELINE_MAP = {
     "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
     "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
     {F_bias},
     false,
     {F_lse},
-    {F_dropout},
     {F_squant},
+    {F_pagedkv},
     kHasUnevenSplits,
     {F_occupancy}>;
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
-    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
@@ -86,7 +93,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;

-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
     }};
 }}

-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>

 template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if constexpr({F_mode} == false) {{ // batch mode
-        if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
+        // we don't check every seqlen_k values for kvcache
+        if (a.seqlen_k_ptr != nullptr) {{
+            kernel_runner<true>::run(s, a);
+        // make sure F_bn0 is divisible by F_bk1
+        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
             kernel_runner<false>::run(s, a);
         }} else {{
             kernel_runner<true>::run(s, a);
@@ -160,7 +172,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;

-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
 #include <iostream>

 template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
 #include <iostream>

 template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
-float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if(s.log_level_ > 0)
         std::cout
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
         );
 }}

-float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
     float r = -1;
 {F_dispatch}
     return r;
 }}
 """

-FMHA_FWD_SPLITKV_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
-                    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH="""    {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+                    ((a.block_table_ptr != nullptr) == {F_pagedkv}) &&
+                    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}
 """

+@dataclass
+class FmhaFwdSplitKVApiTrait:
+    pipeline_tag : str
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim    : str
+    dtype   : str  # data type
+    mode    : str  # value from MODE_MAP
+    bm0     : int  # tile size along q seqlen (block size)
+    bn0     : int  # tile size along qk seqlen
+    bk0     : int  # tile size along qk gemm unroll
+    bn1     : int  # tile size along v head_dim
+    bk1     : int  # tile size along kv gemm unroll
+    bk0blen : int
+    vlayout : str
+    mask    : str
+    bias    : str  #
+    lse     : str  #
+    squant  : str  #
+    spad    : str
+    skpad   : str
+    dpad    : str
+    dvpad   : str
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-' +\
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-' +\
+            f'{self.dvpad}-{self.pagedkv}'
+
+    @property
+    def scheck(self) -> str:
+        if self.mode == 'group' : return 'true/*group mode spad always true*/'  # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.spad == 't' : return 'true'  # always support
+            else :                return 'true'
+        elif self.pipeline_tag in ['qr']:
+            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.seqlen_q % {self.bm0} == 0'
+        else : assert False
+
+    @property
+    def skcheck(self) -> str:
+        if self.mode == 'group' : return 'true/*group mode skpad always true*/'  # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
+            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
+        elif self.pipeline_tag in ['qr', 'qr_fp8']:
+            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.seqlen_k % {self.bn0} == 0'
+        else : assert False
+
+    @property
+    def dcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dpad == 't' : return f'a.hdim_q % {vec} == 0'
+            else : assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dpad == 't' : return f'true /*a.hdim_q % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.hdim_q % {self.bk0blen} == 0'
+        else : assert False
+
+    @property
+    def dvcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dvpad == 't' : return f'a.hdim_v % {vec} == 0'
+            else : assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dvpad == 't' : return f'true /*a.hdim_v % {self.bk0blen} != 0*/'  # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.hdim_v % {self.bk0blen} == 0'
+        else : assert False
+
 @dataclass
 class FmhaFwdSplitKVPipeline:
     tag : str
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
     F_dvpad   : str  #
     F_bias    : str  # true/false
     F_lse     : str  #
-    F_dropout : str  #
     F_squant  : str  #
+    F_pagedkv : str  # t/f
     F_mask    : str  # value from MASK_MAP
     @property
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
         else:
             if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
         if self.F_lse == 't' : n += '_lse'
-        if self.F_dropout == 't' : n += '_dropout'
         if self.F_squant == 't' : n += '_squant'
+        if self.F_pagedkv == 't' : n += '_pagedkv'
         return n

 @dataclass
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
         self.pool = dict()
         self.mask_impl = mask_impl

-    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
+    def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
         # TODO: do we need to check duplication?
         if trait.dtype not in self.pool.keys():
             self.pool[trait.dtype] = dict()
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
                     inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                    F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                    F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                                   F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
+                                   F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                                    F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                    F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                                    F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
                 F_dvpad    = BOOL_MAP[self.F_pipeline.F_dvpad],
                 F_bias     = BIAS_MAP[self.F_pipeline.F_bias],
                 F_lse      = BOOL_MAP[self.F_pipeline.F_lse],
-                F_dropout  = BOOL_MAP[self.F_pipeline.F_dropout],
                 F_squant   = BOOL_MAP[self.F_pipeline.F_squant],
+                F_pagedkv  = BOOL_MAP[self.F_pipeline.F_pagedkv],
                 F_occupancy = self.F_tile.F_occupancy,
                 F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                 F_mask     = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
     def filename(self) -> str:
         return self.name + ".cpp"

-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
+    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
+        return FmhaFwdSplitKVApiTrait(
             pipeline_tag=self.F_pipeline.tag,
             hdim=str(self.F_hdim),
             dtype=self.F_dtype,
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
             mask=self.F_pipeline.F_mask,
             bias=self.F_pipeline.F_bias,
             lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            pagedkv=self.F_pipeline.F_pagedkv,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
     def filename(self) -> str:
         return self.name + ".cpp"

-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
-            pipeline_tag=self.F_pipeline.tag,
-            hdim=str(self.F_hdim),
-            dtype=self.F_dtype,
-            mode=self.F_mode,
-            bm0=self.F_tile.F_bm0,
-            bn0=self.F_tile.F_bn0,
-            bk0=self.F_tile.F_bk0,
-            bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0blen=self.F_tile.F_bk0blen,
-            vlayout=self.F_pipeline.F_vlayout,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
-            squant=self.F_pipeline.F_squant,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad)
-
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
+            # splitkv kernel donot support dropout
+            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                 # TODO: use async pipeline when compiler is more stable
-                if hdim == 256:
+                if hdim == 256 or hdim in [32, 64, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/dropout kernels
+            # no need lse/paged-kv kernels
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
         else:
             assert False
         return pipelines
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
             if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                 # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                 continue
+            if pipeline.F_pagedkv == 't':
+                # we only use batch mode kernels to handle (paged-) kvcache problems
+                continue
             k = Kernel(F_idx=0,
                        F_hdim=hdim,
                        F_dtype=dtype,
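The key behavioral change in the generated batch-mode dispatcher above: when a per-batch seqlen_k_ptr is provided (kvcache), split boundaries cannot be verified cheaply, so the kernel specialized with kHasUnevenSplits=true is chosen without inspecting every length. A minimal Python sketch of that decision (illustrative, mirroring the generated C++; names are placeholders):

# Illustrative sketch of the generated batch-mode dispatch (placeholder names).
def pick_uneven_splits(seqlen_k_ptr, seqlen_k, num_splits, bn0):
    """Return True when the kernel built with kHasUnevenSplits=true must run."""
    if seqlen_k_ptr is not None:
        # kvcache: per-batch lengths are unknown at dispatch time, so we
        # conservatively assume splits may be uneven
        return True
    # fixed seqlen_k: splits are even only when every split covers a whole
    # number of BN0-sized tiles (the generator also keeps BN0 divisible by BK1)
    return seqlen_k % (num_splits * bn0) != 0

assert pick_uneven_splits(None, 4096, 4, 128) is False   # 4096 % 512 == 0
assert pick_uneven_splits(None, 4000, 4, 128) is True
assert pick_uneven_splits(object(), 4096, 4, 128) is True  # kvcache path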
example/ck_tile/01_fmha/fmha_fwd.cpp

(diff collapsed in the original page; contents not shown)
example/ck_tile/01_fmha/fmha_fwd.hpp
View file @
e536d321
...
@@ -5,10 +5,13 @@
...
@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
#include <type_traits>
template
<
typename
DataType
>
template
<
typename
DataType
>
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
const
void
*
v_ptr
;
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
rand_val_ptr
;
void
*
rand_val_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
// only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
hdim_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
float
scale_s
;
float
scale_p
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_bias
;
ck_tile
::
index_t
batch_stride_randval
;
ck_tile
::
index_t
batch_stride_lse
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
window_size_left
;
ck_tile
::
index_t
window_size_right
;
ck_tile
::
index_t
mask_type
;
float
p_drop
;
bool
s_randval
;
std
::
tuple
<
uint64_t
,
uint64_t
>
drop_seed_offset
;
};
struct
fmha_fwd_splitkv_args
{
const
void
*
q_ptr
;
const
void
*
k_ptr
;
const
void
*
v_ptr
;
const
void
*
bias_ptr
;
// bias or alibi_slope pointer
void
*
lse_acc_ptr
;
void
*
lse_acc_ptr
;
void
*
o_acc_ptr
;
void
*
o_acc_ptr
;
void
*
lse_ptr
;
void
*
lse_ptr
;
void
*
o_ptr
;
void
*
o_ptr
;
void
*
block_table_ptr
;
ck_tile
::
index_t
batch_stride_block_table
;
// only used if 'block_table_ptr' is not nullptr
ck_tile
::
index_t
page_block_size
;
// only used if 'block_table_ptr' is not nullptr
const
void
*
cache_batch_idx
;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_q_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqstart_k_ptr
;
const
void
*
seqlen_k_ptr
;
const
void
*
seqlen_k_ptr
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
seqlen_k
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
batch
;
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_q
;
ck_tile
::
index_t
nhead_k
;
ck_tile
::
index_t
nhead_k
;
ck_tile
::
index_t
num_splits
;
ck_tile
::
index_t
num_splits
;
float
scale_s
;
float
scale_s
;
float
scale_p
;
float
scale_p
;
float
scale_o
;
float
scale_o
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_bias
;
// if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile
::
index_t
stride_randval
;
ck_tile
::
index_t
stride_o_acc
;
ck_tile
::
index_t
stride_o_acc
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
stride_o
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_q
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_k
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_bias
;
ck_tile
::
index_t
nhead_stride_randval
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_lse
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_lse_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o_acc
;
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
-   ck_tile::index_t batch_stride_randval;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_lse_acc;
    ck_tile::index_t batch_stride_o_acc;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t split_stride_lse_acc;
    ck_tile::index_t split_stride_o_acc;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
-   float p_drop;
-   bool s_randval;
-   std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
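For reference, the per-batch effective sequence lengths documented in the comment block above can be derived on the host as follows. This is a minimal sketch, assuming 32-bit seqstart buffers as used elsewhere in this example; it is illustrative, not the kernel's exact code.

#include <cstdint>

// Host-side sketch of the seqlen rules in fmha_fwd_splitkv_args;
// 'b' is the batch index. Names mirror the struct fields above.
inline int32_t effective_seqlen_q(const fmha_fwd_splitkv_args& a, int b)
{
    if(a.seqstart_q_ptr != nullptr) // group mode
    {
        const int32_t* seqstart_q = static_cast<const int32_t*>(a.seqstart_q_ptr);
        return seqstart_q[b + 1] - seqstart_q[b];
    }
    return a.seqlen_q; // batch & kvcache modes
}

inline int32_t effective_seqlen_k(const fmha_fwd_splitkv_args& a, int b)
{
    if(a.seqstart_k_ptr != nullptr) // group & kvcache modes
    {
        const int32_t* seqstart_k = static_cast<const int32_t*>(a.seqstart_k_ptr);
        return seqstart_k[b + 1] - seqstart_k[b];
    }
    return a.seqlen_k; // batch mode
}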
struct fmha_fwd_appendkv_args
{
    void* q_ptr;
    void* k_ptr;
    const void* knew_ptr;
    void* v_ptr;
    const void* vnew_ptr;

    const void* seqlen_k_ptr;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_knew;
    ck_tile::index_t batch;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;

    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
    ck_tile::index_t rotary_dim;
    bool has_mask;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr

    const void* cache_batch_idx;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_knew;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_vnew;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_knew;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_vnew;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_knew;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_vnew;
};
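The paged-KV fields follow the usual block-table convention: logical token t of batch b lives in physical page block_table[b * batch_stride_block_table + t / page_block_size]. A minimal sketch of that lookup, assuming 32-bit block-table entries (the element type is not spelled out in this header):

#include <cstdint>

// Illustrative lookup of the physical KV-cache block and in-block offset
// for logical token 't' of batch 'b'. The struct 'paged_kv_pos' is a
// hypothetical helper for this sketch only.
struct paged_kv_pos
{
    int32_t physical_block;
    int32_t offset_in_block; // 0 <= offset_in_block < page_block_size
};

inline paged_kv_pos locate_kv_token(const fmha_fwd_appendkv_args& a, int b, int t)
{
    const int32_t* block_table = static_cast<const int32_t*>(a.block_table_ptr);
    const int32_t* row         = block_table + b * a.batch_stride_block_table;
    return {row[t / a.page_block_size], static_cast<int32_t>(t % a.page_block_size)};
}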
template <typename FmhaKernel>
...
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}

template <typename Kernel>
-auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
...
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.k_ptr,
            args.v_ptr,
            args.bias_ptr,
-           args.rand_val_ptr,
            args.lse_acc_ptr,
            args.o_acc_ptr,
            args.batch,
-           args.max_seqlen_q,
            args.seqstart_q_ptr,
            args.seqstart_k_ptr,
            args.seqlen_k_ptr,
...
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.stride_k,
            args.stride_v,
            args.stride_bias,
-           args.stride_randval,
            args.stride_o_acc,
            args.nhead_stride_q,
            args.nhead_stride_k,
            args.nhead_stride_v,
            args.nhead_stride_bias,
-           args.nhead_stride_randval,
            args.nhead_stride_lse_acc,
            args.nhead_stride_o_acc,
+           args.batch_stride_k,
+           args.batch_stride_v,
+           args.batch_stride_lse_acc,
            args.batch_stride_o_acc,
            args.split_stride_lse_acc,
            args.split_stride_o_acc,
            args.window_size_left,
            args.window_size_right,
-           args.mask_type,
-           args.p_drop,
-           args.s_randval,
-           args.drop_seed_offset);
+           args.mask_type);
    }
    else
    { // create batch mode kernel arguments
...
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
            args.k_ptr,
            args.v_ptr,
            args.bias_ptr,
-           args.rand_val_ptr,
            args.lse_acc_ptr,
            args.o_acc_ptr,
            args.batch,
-           args.max_seqlen_q,
            args.seqlen_q,
            args.seqlen_k,
+           args.seqlen_k_ptr,
            args.hdim_q,
            args.hdim_v,
            args.nhead_q,
            args.nhead_q / args.nhead_k,
            args.num_splits,
+           args.block_table_ptr,
+           args.batch_stride_block_table,
+           args.page_block_size,
+           args.cache_batch_idx,
            args.scale_s,
            args.scale_p,
            args.stride_q,
            args.stride_k,
            args.stride_v,
            args.stride_bias,
-           args.stride_randval,
            args.stride_o_acc,
            args.nhead_stride_q,
            args.nhead_stride_k,
            args.nhead_stride_v,
            args.nhead_stride_bias,
-           args.nhead_stride_randval,
            args.nhead_stride_lse_acc,
            args.nhead_stride_o_acc,
            args.batch_stride_q,
            args.batch_stride_k,
            args.batch_stride_v,
            args.batch_stride_bias,
-           args.batch_stride_randval,
            args.batch_stride_lse_acc,
            args.batch_stride_o_acc,
            args.split_stride_lse_acc,
            args.split_stride_o_acc,
            args.window_size_left,
            args.window_size_right,
-           args.mask_type,
-           args.p_drop,
-           args.s_randval,
-           args.drop_seed_offset);
+           args.mask_type);
    }
}();
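As background for the combine step that later consumes lse_acc/o_acc: each split produces a partial output O_i and a partial log-sum-exp L_i, and the final output is their LSE-weighted average. A minimal scalar sketch of that reduction (illustrative only; the combine kernel implements this on-device, per row and per output element):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Combine per-split partial results:
//   L = log(sum_i exp(L_i)),  O = sum_i exp(L_i - L) * O_i
// 'lse_acc[i]' / 'o_acc[i]' stand in for one (row, hdim-element) slot of
// the split accumulation buffers.
inline float combine_splits(const std::vector<float>& lse_acc,
                            const std::vector<float>& o_acc)
{
    float lse_max = -INFINITY;
    for(float l : lse_acc)
        lse_max = std::max(lse_max, l);

    float sum_exp = 0.f;
    for(float l : lse_acc)
        sum_exp += std::exp(l - lse_max); // shift by the max for numerical stability

    const float lse = lse_max + std::log(sum_exp);

    float o = 0.f;
    for(std::size_t i = 0; i < o_acc.size(); ++i)
        o += std::exp(lse_acc[i] - lse) * o_acc[i];
    return o;
}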
...
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
}

template <typename Kernel>
-auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
...
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
    return ck_tile::make_tuple(kargs, grids);
}
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = Kernel::MakeKargs(args.q_ptr,
                                   args.k_ptr,
                                   args.knew_ptr,
                                   args.v_ptr,
                                   args.vnew_ptr,
                                   args.seqlen_q,
                                   args.seqlen_k_ptr,
                                   args.seqlen_knew,
                                   args.hdim_q,
                                   args.hdim_v,
                                   args.nhead_q,
                                   args.nhead_q / args.nhead_k,
                                   args.rotary_cos_ptr,
                                   args.rotary_sin_ptr,
                                   args.rotary_dim,
                                   args.has_mask,
                                   args.block_table_ptr,
                                   args.batch_stride_block_table,
                                   args.page_block_size,
                                   args.cache_batch_idx,
                                   args.stride_q,
                                   args.stride_k,
                                   args.stride_knew,
                                   args.stride_v,
                                   args.stride_vnew,
                                   args.nhead_stride_q,
                                   args.nhead_stride_k,
                                   args.nhead_stride_knew,
                                   args.nhead_stride_v,
                                   args.nhead_stride_vnew,
                                   args.batch_stride_q,
                                   args.batch_stride_k,
                                   args.batch_stride_knew,
                                   args.batch_stride_v,
                                   args.batch_stride_vnew);

    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
    return ck_tile::make_tuple(kargs, grids);
}
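The assert(args.nhead_q % args.nhead_k == 0) above encodes the usual GQA/MQA convention: nhead_q / nhead_k query heads share one KV head, and the kernel receives that ratio directly. A one-line sketch of the mapping (illustrative, not part of this header):

// Query head 'hq' reads KV head hq / ratio, where ratio = nhead_q / nhead_k
// (MHA: ratio == 1; MQA: nhead_k == 1).
inline int kv_head_of(int hq, int nhead_q, int nhead_k)
{
    const int ratio = nhead_q / nhead_k; // exact division, guaranteed by the assert
    return hq / ratio;
}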
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
...
@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
<
typename
Traits_
>
template
<
typename
Traits_
>
float
fmha_fwd_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
float
fmha_fwd_
(
const
ck_tile
::
stream_config
&
,
fmha_fwd_args
);
template <ck_tile::index_t HDim_,
          typename DataType_,
          bool kIsGroupMode_,
          ck_tile::index_t kM0_,
          ck_tile::index_t kN0_,
          ck_tile::index_t kK0_,
          ck_tile::index_t kN1_,
          ck_tile::index_t kK1_,
          ck_tile::index_t kK0BlockLength_,
          bool kIsVLayoutRowMajor_,
          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
          typename FmhaMask_,
          ck_tile::BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLse_,
          bool kDoFp8StaticQuant_,
          bool kIsPagedKV_,
          bool kPadS_,
          bool kPadSK_,
          bool kPadD_,
          bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType = ck_tile::remove_cvref_t<DataType_>;
    static constexpr bool kIsGroupMode = kIsGroupMode_;
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN0 = kN0_;
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;
    static constexpr ck_tile::index_t kK1 = kK1_;
    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
    static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
    using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
    static constexpr auto BiasEnum = BiasEnum_;
    static constexpr bool kStoreLse = kStoreLse_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr bool kPadS = kPadS_;
    static constexpr bool kPadSK = kPadSK_;
    static constexpr bool kPadD = kPadD_;
    static constexpr bool kPadDv = kPadDv_;
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
-void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_();
...
@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
};
template <typename Traits_>
-void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
          ck_tile::index_t kTileSizeS_,
          ck_tile::index_t kTileSizeSk_,
          ck_tile::index_t kTileSizeD_,
          ck_tile::index_t kTileSizeDv_,
          bool kIsVLayoutRowMajor_,
          bool kPadS_,
          bool kPadSk_,
          bool kPadD_,
          bool kPadDv_,
          ck_tile::RotaryEmbeddingEnum RotaryEnum_,
          bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
    static constexpr ck_tile::index_t HDim = HDim_;
    using DataType = ck_tile::remove_cvref_t<DataType_>;
    static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
    static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
    static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
    static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
    static constexpr bool kPadS = kPadS_;
    static constexpr bool kPadSk = kPadSk_;
    static constexpr bool kPadD = kPadD_;
    static constexpr bool kPadDv = kPadDv_;
    static constexpr auto RotaryEnum = RotaryEnum_;
    static constexpr bool kIsPagedKV = kIsPagedKV_;
};

template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
...
@@ -508,4 +743,32 @@ struct fmha_fwd_traits
    // TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
-float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);

struct fmha_fwd_splitkv_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_group_mode;
    bool is_v_rowmajor;
    mask_enum mask_type;
    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
    bool has_lse;
    bool do_fp8_static_quant;
    // TODO: padding check is inside this api
};
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, fmha_fwd_splitkv_args, const ck_tile::stream_config&);

struct fmha_fwd_appendkv_traits
{
    int hdim_q;
    int hdim_v;
    std::string data_type;
    bool is_v_rowmajor;
    rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&);
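Taken together, a caller fills a traits struct plus the matching args struct and invokes the corresponding entry point. A minimal sketch of a batch-mode fp16 call; the full fmha_fwd_traits body is elided in this diff, so the assigned fields and all values here are illustrative placeholders, and validation/error handling is omitted:

// Illustrative only: assumes device buffers d_q/d_k/d_v/d_o already exist
// and that sizes/strides are filled per the struct definitions above.
float run_fmha_fwd_fp16(void* d_q, void* d_k, void* d_v, void* d_o, int hdim)
{
    fmha_fwd_traits traits{};
    traits.hdim_q    = hdim;
    traits.hdim_v    = hdim;
    traits.data_type = "fp16";
    // ... remaining traits (mode/mask/bias/lse/quant flags) set as needed

    fmha_fwd_args args{};
    args.q_ptr = d_q;
    args.k_ptr = d_k;
    args.v_ptr = d_v;
    args.o_ptr = d_o;
    // ... seqlens and strides filled here

    ck_tile::stream_config stream{}; // default stream, no timing
    return fmha_fwd(traits, args, stream);
}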
example/ck_tile/01_fmha/generate.py
...
@@ -5,25 +5,30 @@
 import argparse
 from enum import IntEnum
 from pathlib import Path
+import pkgutil
+import sys
 from typing import List, Optional

+import codegen.ops
 from codegen.cmake_config import *
-from codegen.ops import (fmha_fwd, fmha_fwd_splitkv, fmha_bwd)


 class HandlerId(IntEnum):
     LIST_BLOBS = 0
     WRITE_BLOBS = 1


-handlers = {
-    'fwd': (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
-    'fwd_splitkv': (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
-    'bwd': (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
-}
+# inspect all modules under 'codegen.ops' and register API handlers
+ops = []
+for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
+    full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
+    if full_module_name not in sys.modules:
+        ops.append(importer.find_spec(module_name).loader.load_module(module_name))
+unwanted_prefix = 'fmha_'
+handlers = dict([(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
+                  (op.list_blobs, op.write_blobs)) for op in ops])
+assert 0 < len(handlers)


 def write_blobs(output_dir: Optional[str], api_list: List[str], kernel_filter: Optional[str], receipt, mask_impl) -> None:
     if output_dir is None:
...
@@ -103,4 +108,4 @@ if __name__ == "__main__":
     if args.list_blobs is not None:
         list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
     else:
         write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
\ No newline at end of file
example/ck_tile/01_fmha/rotary.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>

// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
                        ck_tile::index_t rotary_dim,
                        std::optional<unsigned> seed = std::nullopt)
{
    // return dummy tensors if we won't apply RoPE at all
    if(rotary_dim <= 0)
    {
        ck_tile::HostTensor<DataType> dummy({1, 1});
        return std::make_tuple(dummy, dummy);
    }

    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_real_distribution<float> generator(0.0f, 1.0f);

    const ck_tile::index_t num_rows = seqlen * 2;
    const ck_tile::index_t num_cols = rotary_dim / 2;

    using std::begin, std::end;

    ck_tile::HostTensor<float> angle({num_rows, num_cols});
    std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });

    ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::cos(origin_value));
    });

    ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::sin(origin_value));
    });

    return std::make_tuple(cos, sin);
}
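As a companion to the generator above, this is roughly how an interleaved-RoPE rotation consumes the cos/sin tables on the host: row t of the tables rotates pairs (x[2i], x[2i+1]) of a query/key row at position t. A float-only sketch under that assumption; the actual kernels implement this on-device and also support the half-rotated layout:

// Rotate one head vector 'x' of length >= rotary_dim in place, interleaved
// convention: pair (x[2*i], x[2*i+1]) is rotated by the angle whose cos/sin
// sit at (position, i) in the tables from generate_rotary_cos_sin.
inline void apply_rope_interleaved(float* x,
                                   const ck_tile::HostTensor<float>& cos,
                                   const ck_tile::HostTensor<float>& sin,
                                   ck_tile::index_t position,
                                   ck_tile::index_t rotary_dim)
{
    for(ck_tile::index_t i = 0; i < rotary_dim / 2; ++i)
    {
        const float c  = cos(position, i);
        const float s  = sin(position, i);
        const float x0 = x[2 * i];
        const float x1 = x[2 * i + 1];
        x[2 * i]       = x0 * c - x1 * s;
        x[2 * i + 1]   = x0 * s + x1 * c;
    }
}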
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
                     const ck_tile::HostTensor<DataType>& sin,
                     ck_tile::index_t seqlen_offset,
                     ck_tile::index_t seqlen)
{
    assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
    assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
    assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));

    const ck_tile::index_t num_rows = seqlen;
    const ck_tile::index_t num_cols = cos.get_length(1);

    ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
    cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });

    ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
    sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });

    return std::make_tuple(cos_pt, sin_pt);
}
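A typical use in the append-KV flow, sketched under the assumption that the cache already holds seqlen_k tokens and seqlen_knew tokens are being appended, so the new tokens need table rows starting at offset seqlen_k (the rotary_dim value is a hypothetical placeholder):

// Illustrative: take the cos/sin rows for the appended token positions
// [seqlen_k, seqlen_k + seqlen_knew).
void example_slice_for_appendkv(ck_tile::index_t seqlen_k, ck_tile::index_t seqlen_knew)
{
    const ck_tile::index_t rotary_dim = 64; // hypothetical value
    auto [cos, sin]         = generate_rotary_cos_sin<float>(seqlen_k + seqlen_knew, rotary_dim);
    auto [cos_new, sin_new] = slice_rotary_cos_sin(cos, sin, seqlen_k, seqlen_knew);
    // cos_new/sin_new would then back rotary_cos_ptr/rotary_sin_ptr in
    // fmha_fwd_appendkv_args (after copying to device memory).
    (void)cos_new;
    (void)sin_new;
}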
example/ck_tile/01_fmha/script/benchmark_bwd.sh

#!/bin/sh
-# TODO: run this script from CK root
+# TODO: run this script from CK root or build directory
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
VALID=0

for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/benchmark_fwd.sh

#!/bin/sh
-# TODO: run this script from CK root
+# TODO: run this script from CK root or build directory
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
VALID=0

for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh

#!/bin/sh
-# TODO: run this script from CK root
+# TODO: run this script from CK root or build directory
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1
...
example/ck_tile/01_fmha/script/smoke_test_fwd.sh

-#!/bin/sh
+#!/bin/bash
-# TODO: run this script from CK root
+# TODO: run this script from CK root or build directory
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4

-set -x
-for prec in "fp16" "bf16" ; do
-for mode in 1 0 ; do
-for perm in 0 1 ; do
-for vlayout in "r" "c" ; do
-for hdim in 32 64 128 256 ; do
-for lse in 0 1 ; do
-for bias in "n" "e" "a" ; do
-for p_drop in 0.0 0.2 ; do
-    # ... the same nine $EXE invocations as below, without the splitkv-related flags ...
-done ; done ; done ; done ; done ; done ; done ; done

TEST_SPLITKV=0
TEST_APPENDKV=0

# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
    case "${opt}" in
        s ) TEST_SPLITKV=1 ;;
        a ) TEST_APPENDKV=1 ;;
        * ) ;;
    esac
done

run_fp16_bf16_tests() {
    local NUM_SPLITS=(1)
    local PAGE_BLOCK_SIZE=(0)
    local CACHE_BATCH_IDX=(0)
    if [ $TEST_SPLITKV -eq 1 ] ; then
        NUM_SPLITS+=(2 3)
        PAGE_BLOCK_SIZE+=(128)
        CACHE_BATCH_IDX+=(1)
    fi

    for prec in "fp16" "bf16" ; do
    for mode in 1 0 ; do
    for perm in 0 1 ; do
    for vlayout in "r" "c" ; do
    for hdim in 32 64 128 256 ; do
    for lse in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for p_drop in 0.0 0.2 ; do
    for num_splits in "${NUM_SPLITS[@]}" ; do
    for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
    for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
        # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
        $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done ; done ; done
    done
}

run_fp8_tests() {
    for perm in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for b in 1 2 ; do
    for hdim in 64 128 256 ; do
        $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done
}

run_fp16_appendkv_tests() {
    for s in $(seq 63 1 65) ; do
    for s_k in 65 129 ; do
    for s_knew in 0 64 $s_k ; do
    for hdim in 32 64 128 256 ; do
    for ri in 0 1 ; do
    for rdim in 0 16 32 $hdim ; do
    for page_block_size in 0 128 ; do
    for cache_batch_idx in 0 1 ; do
        $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done
}

set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
    run_fp16_appendkv_tests
fi
set +x
\ No newline at end of file
example/ck_tile/01_fmha/utils.hpp
...
@@ -3,15 +3,17 @@
#pragma once

+#include <algorithm>
#include <cstdint>
#include <cstdlib>
+#include <functional>
#include <optional>
#include <ostream>
+#include <sstream>
+#include <string>
#include <tuple>
#include <utility>
#include <vector>
-#include <functional>
-#include <string>

#include "ck_tile/core/container/span.hpp"
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std::vector<int32_t> generate_seqlens(mode_enum mode,
                                      unsigned count,
                                      int32_t seqlen_avg,
+                                     int32_t seqlen_min = -1, // if not negative, clamp min
                                      int32_t seqlen_max = -1, // if not negative, clamp max
                                      std::optional<unsigned> seed = std::nullopt)
{
    assert(0 < count);

-   std::vector<int32_t> seqlens(
-       count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
+   seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
+   seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
+   assert(seqlen_min <= seqlen_max);
+
+   std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));

    if(mode == mode_enum::group && 1 < count)
    {
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
        for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
        {
            const size_type to_decrease = next_idx();
-           // make sure each element of seqlens is always greater than 0
-           if(seqlens[to_decrease] == 1)
+           // make sure each element of seqlens is in range [seqlen_min, seqlen_max]
+           if(seqlens[to_decrease] == seqlen_min)
            {
                continue;
            }
            const size_type to_increase = (to_decrease + next_step()) % count;
-           if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
+           if(seqlens[to_increase] >= seqlen_max)
            {
                continue;
            }
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                        unsigned count,
                                        int32_t seqlen_avg,
+                                       int32_t seqlen_min = -1,
                                        int32_t seqlen_max = -1,
                                        std::optional<unsigned> seed = std::nullopt)
{
-   return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
+   return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}

+// return random integer generated uniformly in range [low, high]
+template <typename Int = int>
+auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
+    -> std::enable_if_t<std::is_integral_v<Int>, Int>
+{
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::uniform_int_distribution<Int> dist(low, high);
+    return dist(engine);
+}
+
+// return random integers generated uniformly in range [low, high]
+template <typename Int, typename ForwardIterator>
+auto randints(ForwardIterator first,
+              ForwardIterator last,
+              Int low,
+              Int high,
+              std::optional<unsigned> seed = std::nullopt)
+    -> std::enable_if_t<std::is_integral_v<Int>>
+{
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::uniform_int_distribution<Int> dist(low, high);
+    std::generate(first, last, [&] { return dist(engine); });
+}
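A quick sketch of how these helpers are meant to be used; values and the fixed seed are arbitrary, and passing a seed makes runs reproducible while omitting it draws from std::random_device:

#include <vector>

void example_random_helpers()
{
    // single draw in [0, 9], reproducible because a seed is given
    int pick = randint(0, 9, /*seed=*/42u);

    // fill a whole range with draws in [1, 128]
    std::vector<int> lens(8);
    randints(lens.begin(), lens.end(), 1, 128, /*seed=*/42u);

    (void)pick;
}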
/*
...
@@ -112,16 +144,45 @@ decode_seqlen(mode_enum mode,
              std::string q_val,
              std::string k_val,
              std::string k_pad_val,
+             ck_tile::index_t seqlen_k_min = 0,
+             bool use_kvcache = false,
              std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
    if(mode == mode_enum::batch)
    {
        ck_tile::index_t q = _S2I_(q_val);
        ck_tile::index_t k = _S2I_(k_val);
        auto s_q = std::vector<ck_tile::index_t>(batch, q);
-       auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
+       auto s_k = [&] {
+           const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
+           std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
+           if(1 < batch && use_kvcache)
+           {
+               // to keep the original s_k value, we always use seqlen_k_max in first batch
+               randints(std::next(seqlen_ks.begin()),
+                        seqlen_ks.end(),
+                        seqlen_k_min,
+                        seqlen_k_max,
+                        seed);
+               return seqlen_ks;
+           }
+           return seqlen_ks;
+       }();
        auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
+
+       // s_k should be greater than or equal to seqlen_k_min if provided
+       if(s_k.back() < seqlen_k_min)
+       {
+           std::ostringstream msg;
+           msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+               << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+           throw std::runtime_error(msg.str());
+       }
+
        return std::make_tuple(s_q, s_k, s_kpad);
    }
    else
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
            s_q.push_back(q);
            s_k.push_back(k < 0 ? q : k);
            s_kpad.push_back(kp);
+
+           // s_k should be greater than or equal to seqlen_k_min
+           if(s_k.back() < seqlen_k_min)
+           {
+               std::ostringstream msg;
+               msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+                   << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+               throw std::runtime_error(msg.str());
+           }
+
            idx++;
            if(found_q == std::string::npos || idx >= batch)
            {
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
    }
    if(idx < batch)
    {
-       auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
-       auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
+       auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
+       auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);

        s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
        s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
        r = std::atoi(v);
    return r;
}

+template <typename RandomAccessIterator, typename Int>
+std::enable_if_t<std::is_integral_v<Int>>
+iota_shuffle(RandomAccessIterator first,
+             RandomAccessIterator last,
+             Int value,
+             std::optional<unsigned> seed = std::nullopt)
+{
+    std::iota(first, last, value);
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::shuffle(first, last, engine);
+}
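For example, this is handy for building a randomized cache_batch_idx permutation on the host; the tie-in to cache_batch_idx is this example's own usage pattern, and the seed is arbitrary:

#include <cstdint>
#include <vector>

void example_iota_shuffle(int batch)
{
    // produce a random permutation of 0, 1, ..., batch-1
    std::vector<int32_t> cache_batch_idx(batch);
    iota_shuffle(cache_batch_idx.begin(), cache_batch_idx.end(), 0, /*seed=*/7u);
}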
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp

...
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
            return false;

        if constexpr(!((NDimSpatial == 1 &&
-                       (is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-                        is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
+                       (is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+                        is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
                       (NDimSpatial == 2 &&
-                       (is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                        is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
+                       (is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                        is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
                       (NDimSpatial == 3 &&
-                       (is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                        is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
+                       (is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                        is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
        {
            return false;
        }
...
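These hunks swap K and C in the expected input/output layout tags (e.g. is_NWGK_GKXC_NWGC becomes is_NWGC_GKXC_NWGK). The guard itself is a compile-time tag comparison; a minimal sketch of the pattern, with hypothetical stand-in types rather than the real ck::tensor_layout tags:

#include <type_traits>

// Hypothetical layout tags for this sketch only.
struct GNWC {};
struct GKXC {};
struct GNWK {};

// Each is_* helper is just a constexpr comparison of layout tags.
template <typename In, typename Wei, typename Out>
constexpr bool is_GNWC_GKXC_GNWK_sketch()
{
    return std::is_same_v<In, GNWC> && std::is_same_v<Wei, GKXC> &&
           std::is_same_v<Out, GNWK>;
}

// Mirrors the guard pattern above: unsupported layout combinations are
// rejected at compile time, so no kernel is ever instantiated for them.
template <typename In, typename Wei, typename Out>
constexpr bool supported_1d_layout()
{
    if constexpr(!is_GNWC_GKXC_GNWK_sketch<In, Wei, Out>())
        return false;
    else
        return true;
}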
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp

...
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
        }

        if constexpr(NDimSpatial == 1)
        {
-           if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+           if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 2)
        {
-           if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 3)
        {
-           if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
This diff is collapsed.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp

...
@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
            return false;
        }

-       if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                      is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+       if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                      is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
        {
            return false;
        }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

...
@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
        }

        if constexpr(NDimSpatial == 1)
        {
-           if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+           if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 2)
        {
-           if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 3)
        {
-           if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
...