gaoqiong / composable_kernel_ROCM / Commits

Commit d71189ff (unverified), authored Sep 03, 2024 by Rostyslav Geyyer, committed by GitHub on Sep 03, 2024.

    Merge branch 'develop' into lwpck-1815

Parents: f84e2020, 73b67f29
Changes: 74. Showing 20 changed files with 2236 additions and 435 deletions (+2236, -435).
Changed files:

- Jenkinsfile (+37, -11)
- example/ck_tile/01_fmha/CMakeLists.txt (+37, -3)
- example/ck_tile/01_fmha/codegen/cpp_symbol_map.py (+13, -1)
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py (+355, -0)
- example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py (+127, -61)
- example/ck_tile/01_fmha/fmha_fwd.cpp (+604, -164)
- example/ck_tile/01_fmha/fmha_fwd.hpp (+293, -30)
- example/ck_tile/01_fmha/generate.py (+16, -11)
- example/ck_tile/01_fmha/rotary.hpp (+84, -0)
- example/ck_tile/01_fmha/script/benchmark_bwd.sh (+2, -3)
- example/ck_tile/01_fmha/script/benchmark_fwd.sh (+2, -3)
- example/ck_tile/01_fmha/script/smoke_test_bwd.sh (+2, -3)
- example/ck_tile/01_fmha/script/smoke_test_fwd.sh (+94, -41)
- example/ck_tile/01_fmha/utils.hpp (+97, -13)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp (+6, -6)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp (+5, -5)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+373, -39)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp (+2, -2)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp (+5, -5)
- include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+82, -34)
Jenkinsfile
```diff
@@ -262,10 +262,19 @@ def cmake_build(Map conf=[:]){
     // reduce parallelism when compiling, clang uses too much memory
     def nt = nthreads()
     def cmd
+    def setup_cmd
+    def build_cmd
     def execute_cmd = conf.get("execute_cmd", "")
     if (!setup_args.contains("NO_CK_BUILD")){
-        def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
-        def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+            echo "running ninja build trace"
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
+        }
+        else{
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        }
         cmd = conf.get("cmd", """
             ${setup_cmd}
             ${build_cmd}
```
```diff
@@ -281,7 +290,19 @@ def cmake_build(Map conf=[:]){
     echo cmd
     dir("build"){
+        //build CK
         sh cmd
+        //run tests
+        if (!setup_args.contains("NO_CK_BUILD")){
+            if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+                sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
+                archiveArtifacts "ck_build_trace.json"
+                sh "ninja test"
+            }
+            else{
+                sh "make check"
+            }
+        }
     }
     // Only archive from master or develop
```
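For reference, ninjatracing converts `.ninja_log` into Chrome trace-event JSON, so the archived `ck_build_trace.json` can be opened in chrome://tracing or Perfetto. A minimal sketch (not part of this commit; it assumes the standard trace-event fields `name` and `dur` in microseconds) for finding the slowest build steps:

```python
# Sketch only: inspect the ck_build_trace.json artifact archived by the stage above.
import json

with open("ck_build_trace.json") as f:
    events = json.load(f)

# sort by duration (microseconds) and print the ten slowest compile steps
for e in sorted(events, key=lambda e: e.get("dur", 0), reverse=True)[:10]:
    print(f"{e.get('dur', 0) / 1e6:8.1f}s  {e.get('name', '?')}")
```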
```diff
@@ -543,7 +564,7 @@ def Build_CK(Map conf=[:]){
     cmake_build(conf)
     dir("build"){
         //run tests and examples
-        sh 'make -j check'
+        // sh 'make -j check'
         if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0){
             //we only need the ckProfiler to run the performance tests, so we pack and stash it
             //do not stash profiler on nodes where we don't need to run performance tests
```
```diff
@@ -684,8 +705,8 @@ def process_results(Map conf=[:]){
 //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
 CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2; RUN_CK_TILE_TESTS=true
                                               0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true
-                                              0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
-                                              0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false
+                                              0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
+                                              0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
                                               0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""

 pipeline {
```
```diff
@@ -765,7 +786,10 @@ pipeline {
                 name: "BUILD_GFX12",
                 defaultValue: false,
                 description: "Build CK and run tests on gfx12 (default: OFF)")
+        booleanParam(
+                name: "NINJA_BUILD_TRACE",
+                defaultValue: false,
+                description: "Generate a ninja build trace (default: OFF)")
     }
     environment{
         dbuser = "${dbuser}"
```
```diff
@@ -799,6 +823,7 @@ pipeline {
         }
         agent{ label rocmnode("nogpu") }
         environment{
+            setup_args = "NO_CK_BUILD"
            execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
                        -o -not -path \'*.git*\' -iname \'*.hpp\' \
                        -o -not -path \'*.git*\' -iname \'*.cpp\' \
```
```diff
@@ -815,7 +840,7 @@ pipeline {
                        --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
        }
        steps{
-            buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+            buildHipClangJobAndReboot(setup_args: setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
            archiveArtifacts "build/ck_cppcheck.log"
            cleanWs()
        }
```
```diff
@@ -827,6 +852,7 @@ pipeline {
        }
        agent{ label rocmnode("nogpu") }
        environment{
+            setup_args = "NO_CK_BUILD"
            execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
                        -o -not -path \'*.git*\' -iname \'*.hpp\' \
                        -o -not -path \'*.git*\' -iname \'*.cpp\' \
@@ -838,7 +864,7 @@ pipeline {
                        | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
        }
        steps{
-            buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+            buildHipClangJobAndReboot(setup_args: setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
            cleanWs()
        }
    }
```
```diff
@@ -967,10 +993,10 @@ pipeline {
            }
            agent{ label rocmnode("gfx90a") }
            environment{
-                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
+                setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
                execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                                cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-                                -DGPU_TARGETS="gfx1100;gfx90a" \
+                                -DGPU_TARGETS="gfx90a" \
                                -DCMAKE_CXX_COMPILER="${build_compiler()}" \
                                -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
            }
```
```diff
@@ -1074,7 +1100,7 @@ pipeline {
            options { retry(1) }
            agent{ label rocmnode("gfx90a")}
            environment{
-                setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """
+                setup_args = " NO_CK_BUILD "
            }
            steps{
                runPerfTest(setup_args: setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
```
example/ck_tile/01_fmha/CMakeLists.txt
```diff
-# generate a list of kernels, but not actually emit files at config stage
+# validate user-specified fmha_fwd API list
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
+set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
+  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
+endif()
+
+foreach(api ${FMHA_FWD_ENABLE_APIS})
+  if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
+    message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
+  endif()
+endforeach()
+
+# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
+if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
+endif()
+
+string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
+
+# generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 execute_process(
```
```diff
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 add_custom_command(
```
```diff
@@ -60,6 +80,20 @@ else()
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)

+# conditionally enable call to the fwd_splitkv API in fmha_fwd example
+if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
+endif()
+
+# conditionally enable call to the fwd_appendkv API in fmha_fwd example
+if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
+endif()
+
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
```
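The effect of the new cache variable can be summarized in a few lines of Python (a sketch of the CMake logic above, not code from the commit): the semicolon-separated `FMHA_FWD_ENABLE_APIS` list is validated against the known APIs, "fwd" is forced in, and the result is handed to generate.py as a comma-separated `--api` value, which in turn decides the `CK_TILE_FMHA_FWD_SPLITKV_API` / `CK_TILE_FMHA_FWD_APPENDKV_API` compile definitions.

```python
# Sketch of the CMake logic above (illustrative; the real logic lives in CMakeLists.txt).
KNOWN_APIS = ["fwd", "fwd_splitkv", "fwd_appendkv"]

def resolve_apis(enable_apis: str) -> str:
    apis = KNOWN_APIS[:] if enable_apis == "all" else enable_apis.split(";")
    for api in apis:
        if api not in KNOWN_APIS:
            raise ValueError(f"{api} isn't a known api: {KNOWN_APIS}.")
    if "fwd" not in apis:        # "fwd" is a must-have api for the fmha_fwd example
        apis.append("fwd")
    return ",".join(apis)        # becomes: generate.py --api <...>

assert resolve_apis("fwd_splitkv") == "fwd_splitkv,fwd"
assert resolve_apis("all") == "fwd,fwd_splitkv,fwd_appendkv"
```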
example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
```diff
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
     "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
 }

+ROPE_MAP = {
+    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
+    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
+    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
+}
+
+ROPE_CHECK_MAP = {
+    "no"    : "rope_enum::none",
+    "inter" : "rope_enum::interleaved",
+    "half"  : "rope_enum::half_rotated"
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
```
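A note on how the two new maps are consumed (inferred from the fmha_fwd_appendkv.py codegen later in this commit): ROPE_MAP supplies the C++ template argument, while ROPE_CHECK_MAP supplies the runtime value compared against `t.rope_type` in the generated dispatch. An illustrative sketch:

```python
# Illustrative only: how a codegen trait with rope='inter' is rendered.
rope = "inter"
template_arg   = ROPE_MAP[rope]        # -> "ck_tile::RotaryEmbeddingEnum::INTERLEAVED"
dispatch_check = f"(t.rope_type == {ROPE_CHECK_MAP[rope]})"
# -> "(t.rope_type == rope_enum::interleaved)"
```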
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py (new file, 0 → 100644)
```python
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
    FmhaFwdApiTrait,
    DTYPE_BITS,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)

FMHA_FWD_APPENDKV_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_occupancy}>;

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    {F_bs},
    {F_bsk},
    {F_bd},
    {F_bdv},
    {F_vlayout},
    {F_rope},
    {F_pagedkv},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
    fmha_pipeline_problem_{F_idx}>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
                  fmha_pipeline_{F_idx}>;

using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
                        {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;

#include <iostream>

template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API = """
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """            {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
                        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
                        ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
                using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
                return fmha_fwd_appendkv_<trait_>(s, a);
            }}
"""

@dataclass
class FmhaFwdAppendKVApiTrait:
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    bs      : int  # tile size along q seqlen
    bsk     : int  # tile size along k seqlen
    bd      : int  # tile size along qk gemm unroll
    bdv     : int  # tile size along kv gemm unroll
    vlayout : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str
    rope    : str  # key from ROPE_MAP
    pagedkv : str

    @property
    def name(self) -> str:
        return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-' +\
            f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'

    @property
    def scheck(self) -> str:
        if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
        else :                return f'a.seqlen_q % {self.bs} == 0'

    @property
    def skcheck(self) -> str:
        # we do not check all the values in a.seqlen_k_ptr
        return 'true'

    @property
    def dcheck(self) -> str:
        if self.dpad == 't' : return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else :                return f'a.hdim_q % {self.bd} == 0'

    @property
    def dvcheck(self) -> str:
        if self.dvpad == 't' : return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else :                 return f'a.hdim_v % {self.bdv} == 0'

@dataclass
class FmhaFwdAppendKVPipeline:
    F_vlayout : str  # row/col
    F_spad    : str  # true/false
    F_skpad   : str  #
    F_dpad    : str  #
    F_dvpad   : str  #
    F_rope    : str  # key from ROPE_MAP
    F_pagedkv : str  # t/f

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't': n += 's'
            if self.F_skpad == 't': n += 'sk'
            if self.F_dpad == 't': n += 'd'
            if self.F_dvpad == 't': n += 'dv'
            if n != '': n = 'p' + n
            return n
        pn = pad_name()
        n = f'v{self.F_vlayout[0]}'
        if pn != '': n += f'_{pn}'
        if self.F_rope != 'no': n += f'_{self.F_rope}'
        if self.F_pagedkv == 't': n += '_pagedkv'
        return n

class FmhaFwdAppendKVApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes = str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case = str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits = self.pool[dtype][hdim]
                inners = str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
                        F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
                        F_scheck=trait.scheck, F_skcheck=trait.skcheck,
                        F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                        F_rope_check=ROPE_CHECK_MAP[trait.rope],
                        F_pagedkv=BOOL_MAP[trait.pagedkv],
                        F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad],
                        F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                        F_rope=ROPE_MAP[trait.rope],
                        F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv,
                        F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch=per_dtypes)

@dataclass
class FmhaFwdAppendKVTileSize:
    F_bs        : int  # tile size along q seqlen
    F_bsk       : int  # tile size along k seqlen
    F_bd        : int  # tile size along qk gemm unroll
    F_bdv       : int  # tile size along kv gemm unroll
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
            ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")

@dataclass
class FmhaFwdAppendKVKernel:
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaFwdAppendKVTileSize
    F_pipeline : FmhaFwdAppendKVPipeline
    mask_impl  : str

    @property
    def template(self) -> str:
        kernel_body = str()
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_APPENDKV_KERNEL_BODY.format(
                F_idx=self.F_idx,
                F_hdim=self.F_hdim,
                F_dtype=DTYPE_MAP[self.F_dtype],
                F_bs=self.F_tile.F_bs,
                F_bsk=self.F_tile.F_bsk,
                F_bd=self.F_tile.F_bd,
                F_bdv=self.F_tile.F_bdv,
                F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad=BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                F_rope=ROPE_MAP[self.F_pipeline.F_rope],
                F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_occupancy=self.F_tile.F_occupancy)

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" +\
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
        return FmhaFwdAppendKVApiTrait(
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            bs=self.F_tile.F_bs,
            bsk=self.F_tile.F_bsk,
            bd=self.F_tile.F_bd,
            bdv=self.F_tile.F_bdv,
            vlayout=self.F_pipeline.F_vlayout,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            rope=self.F_pipeline.F_rope,
            pagedkv=self.F_pipeline.F_pagedkv)

# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdAppendKVTileSize(64, 64,  32,  32, -1),
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
        }
    else:
        return None

def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
            # applying rotary embedding, so I just use 't' in inter/half pipelines
            for vlayout in ['row', 'col']:
                for pagedkv in ["t", "f"]:
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))

                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
        elif dtype in ['fp8', 'bf8']:
            # rope/paged-kv is not supported
            pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
        if d == None:
            continue
        for hdim_str in d.keys():
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                k = FmhaFwdAppendKVKernel(F_idx=0,
                                          F_hdim=hdim,
                                          F_dtype=dtype,
                                          F_tile=tile,
                                          F_pipeline=pipeline,
                                          mask_impl=mask_impl)
                if kernel_filter != None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)

def write_single_kernel(kernel : FmhaFwdAppendKVKernel, autogen_dir : Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir : Path) -> None:
    (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)

def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    write_fwd_appendkv_api(api_pool, output_dir)

def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    with file_path.open('a') as f:
        _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
```

(No newline at end of file.)
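A hedged usage sketch of the new module: the actual driver is generate.py invoked with `--api fwd_appendkv` (see the CMakeLists.txt changes above); the paths, `receipt`, and `mask_impl` values below are illustrative assumptions, not from the diff.

```python
# Illustrative driver (not from the commit): emit appendkv kernel sources + API file.
from pathlib import Path
from codegen.ops import fmha_fwd_appendkv

out_dir = Path("build/autogen")              # hypothetical output directory
out_dir.mkdir(parents=True, exist_ok=True)

# list the blob file names first (as done at CMake configure time), then write them
fmha_fwd_appendkv.list_blobs(Path("build/fwd_blob_list.txt"), kernel_filter=None,
                             receipt=0, mask_impl="simplified")
fmha_fwd_appendkv.write_blobs(out_dir, kernel_filter=None, receipt=0, mask_impl="simplified")
```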
example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py
```diff
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
 )

+DTYPE_BITS = {
+    "fp32" : 32,
+    "fp16" : 16,
+    "bf16" : 16,
+    "fp8"  : 8,
+    "bf8"  : 8
+}
+
 FMHA_FWD_SPLITKV_PIPELINE_MAP = {
     "qr"       : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
     "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
```
```diff
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
     {F_bias},
     false,
     {F_lse},
-    {F_dropout},
     {F_squant},
+    {F_pagedkv},
     kHasUnevenSplits,
     {F_occupancy}>;
```
```diff
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
-    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
```
```diff
@@ -86,7 +93,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;

-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
```
```diff
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
     }};
 }}

-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;

 #include <iostream>

 template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if constexpr({F_mode} == false) {{ // batch mode
-        if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
+        // we don't check every seqlen_k values for kvcache
+        if (a.seqlen_k_ptr != nullptr) {{
+            kernel_runner<true>::run(s, a);
+        // make sure F_bn0 is divisible by F_bk1
+        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
             kernel_runner<false>::run(s, a);
         }} else {{
             kernel_runner<true>::run(s, a);
```
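The new batch-mode dispatch above prefers the `kHasUnevenSplits=true` kernel whenever a per-batch seqlen_k pointer is present (kvcache) and otherwise falls back on the divisibility test. A minimal Python restatement (illustrative only, mirroring the generated C++):

```python
# Mirrors the generated C++ dispatch above (illustrative, not part of the diff).
def use_uneven_splits_kernel(seqlen_k_ptr, seqlen_k, num_splits, bn0):
    if seqlen_k_ptr is not None:
        # kvcache: we don't check every seqlen_k value, assume uneven splits
        return True
    # even splits only when seqlen_k divides evenly across num_splits tiles of bn0
    return seqlen_k % (num_splits * bn0) != 0
```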
```diff
@@ -160,7 +172,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;

-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
```
```diff
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
 #include <iostream>

 template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
```
```diff
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
 #include <iostream>

 template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
-float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if(s.log_level_ > 0)
         std::cout
```
```diff
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
     );
 }}

-float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
     float r = -1;
 {F_dispatch}
     return r;
 }}
 """

-FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
-    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+    ((a.block_table_ptr != nullptr) == {F_pagedkv}) &&
+    ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}
 """

+@dataclass
+class FmhaFwdSplitKVApiTrait:
+    pipeline_tag : str
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim    : str
+    dtype   : str  # data type
+    mode    : str  # value from MODE_MAP
+    bm0     : int  # tile size along q seqlen (block size)
+    bn0     : int  # tile size along qk seqlen
+    bk0     : int  # tile size along qk gemm unroll
+    bn1     : int  # tile size along v head_dim
+    bk1     : int  # tile size along kv gemm unroll
+    bk0blen : int
+    vlayout : str
+    mask    : str
+    bias    : str  #
+    lse     : str  #
+    squant  : str  #
+    spad    : str
+    skpad   : str
+    dpad    : str
+    dvpad   : str
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-' +\
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-' +\
+            f'{self.dvpad}-{self.pagedkv}'
+
+    @property
+    def scheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.spad == 't' : return 'true' # always support
+            else :                return 'true'
+        elif self.pipeline_tag in ['qr']:
+            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.seqlen_q % {self.bm0} == 0'
+        else:   assert False
+
+    @property
+    def skcheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
+            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
+        elif self.pipeline_tag in ['qr', 'qr_fp8']:
+            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.seqlen_k % {self.bn0} == 0'
+        else:   assert False
+
+    @property
+    def dcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
+            else:                assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dpad == 't' : return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.hdim_q % {self.bk0blen} == 0'
+        else:   assert False
+
+    @property
+    def dvcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
+            else:                 assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dvpad == 't' : return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.hdim_v % {self.bk0blen} == 0'
+        else:   assert False
+
 @dataclass
 class FmhaFwdSplitKVPipeline:
     tag : str
```
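To make the guard generation above concrete, a small illustrative instantiation (all values are made up, not from the diff) and the C++ guard strings it would produce:

```python
# Illustrative only: one trait and the generated runtime guard strings.
t = FmhaFwdSplitKVApiTrait(pipeline_tag='qr', hdim='128', dtype='fp16', mode='batch',
                           bm0=64, bn0=64, bk0=32, bn1=128, bk1=32, bk0blen=128,
                           vlayout='row', mask='no', bias='no', lse='f', squant='f',
                           spad='f', skpad='f', dpad='f', dvpad='f', pagedkv='f')
print(t.scheck)    # a.seqlen_q % 64 == 0
print(t.skcheck)   # a.seqlen_k % 64 == 0
print(t.dcheck)    # a.hdim_q % 128 == 0
print(t.dvcheck)   # a.hdim_v % 128 == 0
```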
```diff
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
     F_dvpad : str #
     F_bias  : str # true/false
     F_lse   : str #
-    F_dropout : str #
     F_squant : str #
+    F_pagedkv : str # t/f
     F_mask  : str # value from MASK_MAP

     @property
```
```diff
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
         else:
             if self.F_mask != 'no': n += f'_m{self.F_mask[0]}'
         if self.F_lse == 't' : n += '_lse'
-        if self.F_dropout == 't' : n += '_dropout'
         if self.F_squant == 't' : n += '_squant'
+        if self.F_pagedkv == 't' : n += '_pagedkv'
         return n

 @dataclass
```
```diff
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
         self.pool = dict()
         self.mask_impl = mask_impl

-    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
+    def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
         # TODO: do we need to check duplication?
         if trait.dtype not in self.pool.keys():
             self.pool[trait.dtype] = dict()
```
```diff
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
                     inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                    F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                    F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                                   F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
+                                   F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
                                    F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                    F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                                    F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
```
```diff
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
                 F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
                 F_bias=BIAS_MAP[self.F_pipeline.F_bias],
                 F_lse=BOOL_MAP[self.F_pipeline.F_lse],
-                F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
                 F_squant=BOOL_MAP[self.F_pipeline.F_squant],
+                F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
                 F_occupancy=self.F_tile.F_occupancy,
                 F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                 F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
```
```diff
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
     def filename(self) -> str:
         return self.name + ".cpp"

-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
+    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
+        return FmhaFwdSplitKVApiTrait(
             pipeline_tag=self.F_pipeline.tag,
             hdim=str(self.F_hdim),
             dtype=self.F_dtype,
```
```diff
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
             mask=self.F_pipeline.F_mask,
             bias=self.F_pipeline.F_bias,
             lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            pagedkv=self.F_pipeline.F_pagedkv,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
```
```diff
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
     def filename(self) -> str:
         return self.name + ".cpp"

-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
-            pipeline_tag=self.F_pipeline.tag,
-            hdim=str(self.F_hdim),
-            dtype=self.F_dtype,
-            mode=self.F_mode,
-            bm0=self.F_tile.F_bm0,
-            bn0=self.F_tile.F_bn0,
-            bk0=self.F_tile.F_bk0,
-            bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0blen=self.F_tile.F_bk0blen,
-            vlayout=self.F_pipeline.F_vlayout,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
-            squant=self.F_pipeline.F_squant,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad)
-
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
```
```diff
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
+            # splitkv kernel donot support dropout
+            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                 # TODO: use async pipeline when compiler is more stable
-                if hdim == 256:
+                if hdim == 256 or hdim in [32, 64, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/dropout kernels
+            # no need lse/paged-kv kernels
             for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
         else:
             assert False
         return pipelines
```
```diff
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
                 if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                     # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                     continue
+                if pipeline.F_pagedkv == 't':
+                    # we only use batch mode kernels to handle (paged-) kvcache problems
+                    continue
                 k = Kernel(F_idx=0,
                            F_hdim=hdim,
                            F_dtype=dtype,
```
example/ck_tile/01_fmha/fmha_fwd.cpp
```diff
@@ -4,6 +4,7 @@
 #include "fmha_fwd.hpp"
 #include "ck_tile/host.hpp"
 #include "mask.hpp"
+#include "rotary.hpp"
 #include "utils.hpp"

 #include <array>
```
```diff
@@ -16,6 +17,10 @@
 #include <utility>
 #include <vector>

+#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
+#endif
+
 template <typename T>
 std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 {
```
```diff
@@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[])
                 "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                 "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                 "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
-        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
+        .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s")
+        .insert("s_knew",
+                "0",
+                "seqlen_k for new key/value, 0 means not to use this at all; "
+                "-1 to choose s_knew in [1, s] randomly.")
         .insert("s_kpad",
                 "-1",
                 "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
```
```diff
@@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[])
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert("rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
+        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
         .insert("num_splits",
                 "1",
                 "# of splits for key/value. 0 to determine actual number by heuristic")
+        .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
+        .insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");
```
```diff
@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
     return num_splits;
 }

-float fmha_fwd_dispatch(fmha_fwd_traits traits,
-                        fmha_fwd_args args,
-                        const ck_tile::stream_config& config)
-{
-    if(1 < args.num_splits)
-    {
-        return fmha_fwd_splitkv(traits, args, config);
-    }
-    else
-    {
-        return fmha_fwd(traits, args, config);
-    }
-}
-
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
```
...
@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }

+    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
+    if(*seed == 0)
+    {
+        seed.reset();
+    }
+
+    ck_tile::index_t hdim_q = arg_parser.get_int("d");
+    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
+    if(hdim_v < 0)
+        hdim_v = hdim_q;
+
+    ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    if(seqlen_knew != 0)
+    {
+        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        seqlen_knew = 0;
+    }
+#endif
+    if(seqlen_knew < 0)
+    {
+        seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
+    }
+
+    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
+    if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
+                   std::is_same_v<DataType, ck_tile::bf16_t>))
+    {
+        if(0 < rotary_dim)
+        {
+            std::cerr << "rotary embedding is only available for data type=fp16|bf16"
+                      << std::endl;
+            return false;
+        }
+    }
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    else if(0 < rotary_dim)
+    {
+        std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
+                  << std::endl;
+        rotary_dim = 0;
+    }
+#endif
+    if(!(rotary_dim <= hdim_q))
+    {
+        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
+        return false;
+    }
+    else if(!(rotary_dim % 16 == 0))
+    {
+        std::cerr << "only rotary dimensions divisible by 16 are currently supported"
+                  << std::endl;
+        return false;
+    }
+
+    ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < page_block_size)
+    {
+        std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
+                  << std::endl;
+        page_block_size = 0;
+    }
+#endif
+    if(!(page_block_size % 128 == 0))
+    {
+        std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
+                  << std::endl;
+        return false;
+    }
+
+    bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(use_cache_batch_idx)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+#endif
+    if(0 < page_block_size && use_cache_batch_idx)
+    {
+        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                     "'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+
+    // the input tensor layout for kvcache is same as batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
+    if(use_kvcache && mode != mode_enum::batch)
+    {
+        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
+
     auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
                                                               batch,
                                                               arg_parser.get_str("s"),
                                                               arg_parser.get_str("s_k"),
-                                                              arg_parser.get_str("s_kpad"));
+                                                              arg_parser.get_str("s_kpad"),
+                                                              /*seqlen_k_min=*/0 < seqlen_knew
+                                                                  ? seqlen_knew
+                                                                  : 0,
+                                                              use_kvcache);
+
+    // compute kvcache seqlen_k (before appending knew/vnew)
+    auto cache_seqlen_ks = seqlen_ks;
+    std::transform(cache_seqlen_ks.begin(),
+                   cache_seqlen_ks.end(),
+                   cache_seqlen_ks.begin(),
+                   [&](auto seqlen_k) { return seqlen_k - seqlen_knew; });

#if 0
    // clang-format off
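The cache_seqlen_ks transform above just rewinds each sequence length to its state before the append step. A self-contained sketch with concrete numbers (the values are illustrative):

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // Suppose the user asked for seqlen_k = 512 per batch and s_knew = 64 tokens to append.
    std::vector<int> seqlen_ks{512, 512};
    const int seqlen_knew = 64;

    // The cache lengths are the lengths *before* the append step, exactly as above.
    std::vector<int> cache_seqlen_ks = seqlen_ks;
    std::transform(cache_seqlen_ks.begin(), cache_seqlen_ks.end(), cache_seqlen_ks.begin(),
                   [&](int s) { return s - seqlen_knew; });

    std::printf("cache_seqlen_k = %d\n", cache_seqlen_ks[0]); // prints 448
    return 0;
}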
...
@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // clang-format on
 #endif

-    ck_tile::index_t hdim_q = arg_parser.get_int("d");
-    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
-    if(hdim_v < 0)
-        hdim_v = hdim_q;
-
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
...
@@ -357,13 +455,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }

     std::string init_method = arg_parser.get_str("init");
-    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
-    if(*seed == 0)
-    {
-        seed.reset();
-    }
+    const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
+
+    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
+#if !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(num_splits != 1)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
+        num_splits = 1;
+    }
+#endif

-    int num_splits    = arg_parser.get_int("num_splits");
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
...
@@ -425,6 +527,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }

+    const ck_tile::index_t max_num_page_blocks =
+        (0 < page_block_size
+             ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
+             : 0);
+
     // legalize num_splits according to other options
     if(num_splits < 1)
     {
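To make the page-count formula concrete: every batch entry reserves enough fixed-size pages for the longest K sequence. A minimal check of the arithmetic (integer_divide_ceil is stubbed locally for illustration):

#include <algorithm>
#include <cstdio>

// Illustrative stand-in for ck_tile::integer_divide_ceil.
static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int batch = 2, max_seqlen_k = 1000, page_block_size = 128;
    // ceil(1000 / 128) = 8 pages per batch entry, so 2 * 8 = 16 pages in total.
    const int max_num_page_blocks =
        batch * std::max(1, integer_divide_ceil(max_seqlen_k, page_block_size));
    std::printf("max_num_page_blocks = %d\n", max_num_page_blocks); // 16
    return 0;
}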
...
@@ -436,6 +543,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         std::cerr << "num_splits greater than 128 is not supported" << std::endl;
         return false;
     }
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < p_drop && (1 < num_splits || use_kvcache))
+    {
+        std::cerr << "dropout is not supported by split-kv kernels. ignoring the 'p_drop' option"
+                  << std::endl;
+        p_drop = 0.0f;
+    }
+#endif

     auto get_lengths = [&](bool permute,
                            ck_tile::index_t b /*batch*/,
...
@@ -462,11 +577,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
     ck_tile::HostTensor<KDataType> k_host(
-        get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+        0 < page_block_size
+            ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
+            : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+    /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
+    ck_tile::HostTensor<KDataType> knew_host(
+        0 < seqlen_knew
+            ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
+            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor<VDataType> v_host(
-        is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
-                      : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
+        0 < page_block_size
+            ? (is_v_rowmajor
+                   ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
+                   : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
+            : (is_v_rowmajor
+                   ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
+                   : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
+    ck_tile::HostTensor<VDataType> vnew_host(
+        0 < seqlen_knew
+            ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
+                             : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
+            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor<BiasDataType> bias_host(
         bias.type == bias_enum::elementwise_bias
             ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
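With page_block_size set, K/V no longer live in one contiguous [batch, nhead_k, seqlen_k, hdim] tensor but in a pool of fixed-size pages, and a per-batch block table maps a logical token position to a (page, offset) pair. A minimal sketch of that lookup — logical_to_physical and PageRef are illustrative names, not this repository's API:

#include <cstdio>

struct PageRef { int page; int offset; };

// Map a logical K/V token position inside one sequence to its physical page slot, mirroring
// the k_host(block_table_host(b, pos / page_block_size), ..., pos % page_block_size, ...)
// indexing used by the reference check later in this file.
PageRef logical_to_physical(const int* block_table_row, int pos, int page_block_size)
{
    return {block_table_row[pos / page_block_size], pos % page_block_size};
}

int main()
{
    const int page_block_size  = 128;
    const int block_table_b0[] = {5, 2, 9}; // pages owned by batch 0, in logical order
    const PageRef r = logical_to_physical(block_table_b0, 200, page_block_size);
    std::printf("token 200 -> page %d, offset %d\n", r.page, r.offset); // page 2, offset 72
    return 0;
}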
...
@@ -478,12 +608,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
                       : std::array<ck_tile::index_t, 2>{batch, nhead})
             : std::array<ck_tile::index_t, 2>{1, 1});
+    auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
+        std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
     ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits
+        1 < num_splits || use_kvcache
             ? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
             : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
-        1 < num_splits
+        1 < num_splits || use_kvcache
             ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
             : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
...
@@ -500,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
         p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                    : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
+    ck_tile::HostTensor<int32_t> block_table_host(
+        0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
+                            : std::array<ck_tile::index_t, 2>{1, 1});
+    ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
+                                                          ? std::array<ck_tile::index_t, 1>{batch}
+                                                          : std::array<ck_tile::index_t, 1>{1});

     if(init_method == "ui" || init_method == "0")
     {
         ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
         ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
         ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
         ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "ni")
     {
         ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
         ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
         ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
         ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "uf" || init_method == "1")
     {
         ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
         ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
         ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
         ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
     }
     else if(init_method == "nf")
     {
         ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
         ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
+        ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
         ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
+        ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
         ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
     }
     else if(init_method == "tf" || init_method == "2")
     {
         ck_tile::FillTrigValue<QDataType>{}(q_host);
         ck_tile::FillTrigValue<KDataType>{}(k_host);
+        ck_tile::FillTrigValue<KDataType>{}(knew_host);
         ck_tile::FillTrigValue<VDataType>{}(v_host);
+        ck_tile::FillTrigValue<VDataType>{}(vnew_host);
         ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
     }
     else if(init_method == "ufq" || init_method == "uf:q" ||
...
@@ -540,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     {
         ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
         ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
+        ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
         ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
+        ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);

         // bias_fp8 = qscale_bias * bias_fp32
         float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
...
@@ -550,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     if(bias.type == bias_enum::alibi)
     {
         auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
-        assert(slopes.size() == nhead);
+        assert(slopes.size() == static_cast<std::size_t>(nhead));
         if(bias.rank_info == 0)
         {
             // alibi in 1*h
...
@@ -565,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }
+    iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
+    iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);

     ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
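iota_shuffle (a helper from this example's utils.hpp, also touched in this commit) fills the block table and batch-index map with a random permutation of 0..N-1, so physical pages deliberately do not coincide with logical order. A plausible sketch of such a helper, assuming it is simply iota followed by a seeded shuffle:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

template <typename Iter>
void iota_shuffle(Iter first, Iter last, int start)
{
    std::iota(first, last, start); // 0, 1, 2, ...
    std::mt19937 gen(11939);       // a fixed seed keeps runs reproducible
    std::shuffle(first, last, gen);
}

int main()
{
    std::vector<int> block_table(8);
    iota_shuffle(block_table.begin(), block_table.end(), 0);
    for(int p : block_table)
        std::printf("%d ", p); // a permutation of 0..7
    std::printf("\n");
    return 0;
}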
...
@@ -576,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
     ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
-    ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t));
+    ck_tile::DeviceMem seqlen_k_buf(
+        use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem cache_seqlen_k_buf(
+        need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
+    ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());

     q_buf.ToDevice(q_host.data());
     k_buf.ToDevice(k_host.data());
+    knew_buf.ToDevice(knew_host.data());
     v_buf.ToDevice(v_host.data());
+    vnew_buf.ToDevice(vnew_host.data());
     bias_buf.ToDevice(bias_host.data());
     seqstart_q.ToDevice(seqstart_q_host.data());
     seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
                                             : seqstart_k_with_padding_host.data());
-    seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data());
+    seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
+    cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
+    rotary_cos_buf.ToDevice(rotary_cos_host.data());
+    rotary_sin_buf.ToDevice(rotary_sin_host.data());
     alibi_slope_buf.ToDevice(alibi_slope_host.data());
+    block_table_buf.ToDevice(block_table_host.data());
+    cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());

     // clang-format off
     auto layout_str = [&](bool permute){
         if(permute) return std::string("bhsd");
         else return std::string("bshd");
     };
     auto io_layout = [&](bool iperm_, bool operm_) {
         if(iperm_ == operm_) return layout_str(iperm_);
         else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
     };
     // clang-format on
...
@@ -609,39 +780,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
               << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
               << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
               << ", mask:" << mask << ", v:" << vlayout;
+#if CK_TILE_FMHA_FWD_APPENDKV_API
+    if(0 < rotary_dim)
+    {
+        std::cout << ", rotary_dim:" << rotary_dim << "("
+                  << (is_rotary_interleaved ? "inter" : "half") << ")";
+    }
+#endif
+#if CK_TILE_FMHA_FWD_SPLITKV_API
     if(1 < num_splits)
     {
         std::cout << ", num_splits:" << num_splits;
     }
+    if(0 < page_block_size)
+    {
+        std::cout << ", page_block_size:" << page_block_size;
+    }
+    if(use_cache_batch_idx)
+    {
+        std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
+    }
+#endif
     std::cout << std::flush;

-    auto fmha_traits = fmha_fwd_traits{hdim_q,
-                                       hdim_v,
-                                       data_type,
-                                       mode == mode_enum::group,
-                                       is_v_rowmajor,
-                                       mask.type,
-                                       bias.type,
-                                       lse,
-                                       p_drop > 0.0f,
-                                       squant};
-
-    auto p_compute_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::scales{scale_p};
-        else
-            return ck_tile::identity{};
-    }();
-
-    auto oacc_element_func = [&]() {
-        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
-                                     ck_tile::scales{scale_o});
-        else
-            return ck_tile::identity{};
-    }();
+    const auto init_traits = [&](auto& traits) {
+        traits.hdim_q        = hdim_q;
+        traits.hdim_v        = hdim_v;
+        traits.data_type     = data_type;
+        traits.is_v_rowmajor = is_v_rowmajor;
+
+        if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
+        {
+            traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
+                                                                        : rope_enum::half_rotated)
+                                               : rope_enum::none);
+        }
+        else // fmha_fwd_traits or fmha_splitkv_traits
+        {
+            traits.is_group_mode       = (mode == mode_enum::group);
+            traits.mask_type           = mask.type;
+            traits.bias_type           = bias.type;
+            traits.has_lse             = lse;
+            traits.do_fp8_static_quant = squant;
+            traits.has_dropout         = (p_drop > 0.0f);
+        }
+    };

-    auto fmha_args = [&, k_paddings_ = seqlen_kpads]() {
+    const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
         assert(nhead % nhead_k == 0);
         /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
         ///       seqlen_k] in this example, hence both the 'batch_stride_bias' &
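init_traits/init_args now serve three different argument structs through a single generic lambda that branches at compile time on the deduced type. A stripped-down sketch of that pattern (the two toy structs are illustrative, not the repository's types):

#include <cstdio>
#include <type_traits>

struct fwd_traits      { int hdim = 0; bool has_lse = false; };
struct appendkv_traits { int hdim = 0; int rope_type = 0; };

int main()
{
    // One lambda, two instantiations: the branch not taken for a given type is discarded
    // at compile time, so each struct only needs the members its own branch touches.
    const auto init = [](auto& t) {
        t.hdim = 128;
        if constexpr(std::is_same_v<appendkv_traits, std::decay_t<decltype(t)>>)
            t.rope_type = 1;
        else
            t.has_lse = true;
    };

    fwd_traits a;
    appendkv_traits b;
    init(a);
    init(b);
    std::printf("%d %d %d\n", a.hdim, a.has_lse, b.rope_type); // 128 1 1
    return 0;
}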
...
@@ -649,11 +838,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
         // setup stride_* arguments
         const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
         const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
+        const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
         const ck_tile::index_t stride_v = [&]() {
             if(is_v_rowmajor)
                 return i_perm ? hdim_v : nhead_k * hdim_v;
             else
-                return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
+                return 0 < page_block_size
+                           ? (i_perm ? page_block_size : nhead_k * page_block_size)
+                           : (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
         }();
+        const ck_tile::index_t stride_vnew = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? hdim_v : nhead_k * hdim_v;
+            else
+                return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
+        }();
         const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
         const ck_tile::index_t stride_randval = (max_seqlen_k);
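The stride_* values encode the two supported layouts: i_perm=true means [batch, nhead, seqlen, hdim] (bhsd), i_perm=false means [batch, seqlen, nhead, hdim] (bshd). A small worked sketch with illustrative sizes:

#include <cstdio>

int main()
{
    const int nhead = 8, hdim = 64;

    // bhsd: one head's rows are contiguous, so stepping one token advances hdim elements.
    const int stride_bhsd = hdim;         // the "i_perm ? hdim : ..." branch
    // bshd: all heads of one token are interleaved, so one token step skips nhead*hdim.
    const int stride_bshd = nhead * hdim; // the "... : nhead * hdim" branch

    std::printf("row stride bhsd=%d bshd=%d\n", stride_bhsd, stride_bshd); // 64 512
    return 0;
}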
...
@@ -661,12 +858,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
         // setup nhead_stride_* arguments
         const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
-        const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
+        const ck_tile::index_t nhead_stride_k =
+            (0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
+                                 : (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
+        const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
         const ck_tile::index_t nhead_stride_v = [&]() {
             if(is_v_rowmajor)
-                return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
+                return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
+                                           : (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
             else
-                return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
+                return 0 < page_block_size
+                           ? (i_perm ? hdim_v * page_block_size : page_block_size)
+                           : (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
         }();
+        const ck_tile::index_t nhead_stride_vnew = [&]() {
+            if(is_v_rowmajor)
+                return i_perm ? seqlen_knew * hdim_v : hdim_v;
+            else
+                return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
+        }();
         const ck_tile::index_t nhead_stride_bias =
             (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
...
@@ -677,87 +885,193 @@ bool run(const ck_tile::ArgParser& arg_parser)
         const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
         // setup batch_stride_* arguments
         const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
-        const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
-        const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
+        const ck_tile::index_t batch_stride_k =
+            (0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
+                                 : (nhead_k * shape_seqlen_k * hdim_q));
+        const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
+        const ck_tile::index_t batch_stride_v =
+            (0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
+                                 : (nhead_k * hdim_v * shape_seqlen_k));
+        const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
         const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
         const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
         const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
         const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
         const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
+        const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);

         // setup split_stride_* arguments (only used in split-kv kernel)
         const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
         const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);

-        return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(),
-                             v_buf.GetDeviceBuffer(),
-                             bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
-                                                           : bias_buf.GetDeviceBuffer(),
-                             randval_buf.GetDeviceBuffer(), lse_acc_buf.GetDeviceBuffer(),
-                             o_acc_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(),
-                             o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(),
-                             seqstart_k.GetDeviceBuffer(),
-                             k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(),
-                             shape_seqlen_q, shape_seqlen_k, batch, max_seqlen_q, hdim_q, hdim_v,
-                             nhead, nhead_k, num_splits, scale_s, scale_p, scale_o,
-                             stride_q, stride_k, stride_v,
-                             bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
-                                                           : stride_bias,
-                             stride_randval, stride_o_acc, stride_o,
-                             nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias,
-                             nhead_stride_randval, nhead_stride_lse, nhead_stride_lse_acc,
-                             nhead_stride_o_acc, nhead_stride_o,
-                             batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias,
-                             batch_stride_randval, batch_stride_lse, batch_stride_lse_acc,
-                             batch_stride_o_acc, batch_stride_o,
-                             split_stride_lse_acc, split_stride_o_acc,
-                             mask.left, mask.right, static_cast<ck_tile::index_t>(mask.type),
-                             p_drop, s_randval, {drop_seed, drop_offset}};
-    }();
+        args.q_ptr = q_buf.GetDeviceBuffer();
+        args.k_ptr = k_buf.GetDeviceBuffer();
+        args.v_ptr = v_buf.GetDeviceBuffer();
+
+        args.batch    = batch;
+        args.seqlen_q = shape_seqlen_q; // unused in group mode
+        args.hdim_q   = hdim_q;
+        args.hdim_v   = hdim_v;
+        args.nhead_q  = nhead;
+        args.nhead_k  = nhead_k;
+
+        args.stride_q       = stride_q;
+        args.stride_k       = stride_k;
+        args.stride_v       = stride_v;
+        args.nhead_stride_q = nhead_stride_q;
+        args.nhead_stride_k = nhead_stride_k;
+        args.nhead_stride_v = nhead_stride_v;
+        args.batch_stride_q = batch_stride_q;
+        args.batch_stride_k = batch_stride_k;
+        args.batch_stride_v = batch_stride_v;
+
+        if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
+        {
+            args.knew_ptr    = knew_buf.GetDeviceBuffer();
+            args.vnew_ptr    = vnew_buf.GetDeviceBuffer();
+            args.seqlen_knew = seqlen_knew;
+
+            args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
+
+            args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
+            args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
+            args.rotary_dim     = rotary_dim;
+            args.has_mask       = (mask.type != mask_enum::no_mask);
+
+            args.block_table_ptr =
+                (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
+            args.batch_stride_block_table = batch_stride_block_table;
+            args.page_block_size          = page_block_size;
+
+            args.cache_batch_idx =
+                (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
+
+            args.stride_knew       = stride_knew;
+            args.stride_vnew       = stride_vnew;
+            args.nhead_stride_knew = nhead_stride_knew;
+            args.nhead_stride_vnew = nhead_stride_vnew;
+            args.batch_stride_knew = batch_stride_knew;
+            args.batch_stride_vnew = batch_stride_vnew;
+        }
+        else // fmha_fwd_args or fmha_fwd_splitkv_args
+        {
+            args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
+                                                          : bias_buf.GetDeviceBuffer();
+            args.lse_ptr  = lse_buf.GetDeviceBuffer();
+            args.o_ptr    = o_buf.GetDeviceBuffer();
+
+            args.seqstart_q_ptr =
+                (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
+            args.seqstart_k_ptr =
+                (mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
+            args.seqlen_k_ptr =
+                (use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
+
+            args.seqlen_k     = shape_seqlen_k; // unused in group mode (or kvcache enabled)
+            args.max_seqlen_q = max_seqlen_q;
+
+            args.scale_s = scale_s;
+            args.scale_p = scale_p;
+            args.scale_o = scale_o;
+
+            args.stride_bias = (bias.type == bias_enum::alibi
+                                    ? (bias.rank_info == 0 ? 0 : nhead)
+                                    : stride_bias);
+            args.stride_o          = stride_o;
+            args.nhead_stride_bias = nhead_stride_bias;
+            args.nhead_stride_lse  = nhead_stride_lse;
+            args.nhead_stride_o    = nhead_stride_o;
+            args.batch_stride_bias = batch_stride_bias;
+            args.batch_stride_lse  = batch_stride_lse;
+            args.batch_stride_o    = batch_stride_o;
+
+            args.window_size_left  = mask.left;
+            args.window_size_right = mask.right;
+            args.mask_type         = static_cast<ck_tile::index_t>(mask.type);
+
+            if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
+            {
+                args.rand_val_ptr         = randval_buf.GetDeviceBuffer();
+                args.stride_randval       = stride_randval;
+                args.nhead_stride_randval = nhead_stride_randval;
+                args.batch_stride_randval = batch_stride_randval;
+
+                args.p_drop           = p_drop;
+                args.s_randval        = s_randval;
+                args.drop_seed_offset = std::tie(drop_seed, drop_offset);
+            }
+            else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
+            {
+                args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
+                args.o_acc_ptr   = o_acc_buf.GetDeviceBuffer();
+
+                args.block_table_ptr =
+                    (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
+                args.batch_stride_block_table = batch_stride_block_table;
+                args.page_block_size          = page_block_size;
+
+                args.cache_batch_idx =
+                    (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
+
+                args.num_splits   = num_splits;
+                args.stride_o_acc = stride_o_acc;
+
+                args.nhead_stride_lse_acc = nhead_stride_lse_acc;
+                args.nhead_stride_o_acc   = nhead_stride_o_acc;
+                args.batch_stride_lse_acc = batch_stride_lse_acc;
+                args.batch_stride_o_acc   = batch_stride_o_acc;
+                args.split_stride_lse_acc = split_stride_lse_acc;
+                args.split_stride_o_acc   = split_stride_o_acc;
+            }
+        }
+    };

+    const float appendkv_ave_time = [&] {
+#if CK_TILE_FMHA_FWD_APPENDKV_API
+        if(need_append_kvcache)
+        {
+            fmha_fwd_appendkv_traits fwd_appendkv_traits;
+            init_traits(fwd_appendkv_traits);
+
+            fmha_fwd_appendkv_args fwd_appendkv_args;
+            init_args(fwd_appendkv_args);
+
+            return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
+        }
+#endif
+        return 0.0f;
+    }();

-    float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
+    const float fwd_ave_time = [&] {
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+        if(1 < num_splits || use_kvcache)
+        {
+            fmha_fwd_splitkv_traits fmha_splitkv_traits;
+            init_traits(fmha_splitkv_traits);
+
+            fmha_fwd_splitkv_args fmha_splitkv_args;
+            init_args(fmha_splitkv_args);
+
+            return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
+        }
+#endif
+        fmha_fwd_traits fmha_traits;
+        init_traits(fmha_traits);
+
+        fmha_fwd_args fmha_args;
+        init_args(fmha_args);
+
+        return fmha_fwd(fmha_traits, fmha_args, stream_config);
+    }();

-    if(ave_time < 0)
+    if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
     {
         std::cout << ", not supported yet" << std::flush << std::endl;
         return false;
     }

+    const float ave_time = (appendkv_ave_time + fwd_ave_time);
+
     float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
     float gb_per_sec = num_byte / 1.E6 / ave_time;
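ave_time is now the sum of the optional append-kv pass and the attention pass, and both derived metrics divide by it; since ave_time is in milliseconds, flop / 1e9 / ms yields TFLOPS and bytes / 1e6 / ms yields GB/s. A quick sanity check of the unit conversion with made-up numbers:

#include <cstdio>

int main()
{
    const double flop     = 2.0e12; // 2 TFLOP of work
    const double num_byte = 1.0e9;  // 1 GB of traffic
    const double ave_ms   = 4.0;    // appendkv_ave_time + fwd_ave_time, in milliseconds

    // flop / 1e9 / ms  ==  flop / (ms * 1e-3 s) / 1e12  ->  TFLOP/s
    const double tflops   = flop / 1.e9 / ave_ms;     // 500 TFLOPS
    const double gb_per_s = num_byte / 1.e6 / ave_ms; // 250 GB/s
    std::printf("%.1f TFLOPS, %.1f GB/s\n", tflops, gb_per_s);
    return 0;
}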
...
@@ -775,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
     o_buf.FromDevice(o_host.data());
     lse_buf.FromDevice(lse_host.data());
     randval_buf.FromDevice(randval_host.data());

+    auto p_compute_element_func = [&]() {
+        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+            return ck_tile::scales{scale_p};
+        else
+            return ck_tile::identity{};
+    }();
+
+    auto oacc_element_func = [&]() {
+        if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
+            return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
+                                     ck_tile::scales{scale_o});
+        else
+            return ck_tile::identity{};
+    }();
+
     float p_undrop = 1.0 - p_drop;
     uint8_t p_undrop_in_uint8_t =
         uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
     float rp_undrop = 1.0 / p_undrop;

     bool pass = true;

     for(ck_tile::index_t wb = 0; wb < batch; ++wb)
     {
         const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
         const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];

         // adjust matrix index according to the mode
-        const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
+        const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
+        const ck_tile::index_t cache_b_idx =
+            (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
         const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
         const ck_tile::index_t key_offset =
             (mode == mode_enum::batch
                  ? 0
                  : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));

+        const auto v_host_ref_lengths =
+            std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
+        const auto v_host_ref_strides =
+            is_v_rowmajor
+                ? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
+                : std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
+
         ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
         ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
-        ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
+        ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides);
         ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});

         ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
...
@@ -815,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser)
         // clang-format off
         // permute
-        if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
-        else       q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
+        if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
+        else       q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });

+#if CK_TILE_FMHA_FWD_APPENDKV_API
+        // optionally apply RoPE to the q_host_ref
+        if(0 < rotary_dim)
+        {
+            decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
+
+            auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
+                rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
+
+            ck_tile::reference_batched_rotary_position_embedding(
+                q_host_ref,
+                rotary_cos_slice,
+                rotary_sin_slice,
+                is_rotary_interleaved,
+                q_host_ref_ro,
+                /*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);
+
+            q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
+        }
+#endif

+#if CK_TILE_FMHA_FWD_SPLITKV_API
+        if(0 < page_block_size)
+        {
+            if(i_perm)
+            {
+                k_host_ref.ForEach([&](auto& self, auto i) {
+                    self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
+                                     i[0] / nr, i[1] % page_block_size, i[2]);
+                });
+            }
+            else
+            {
+                k_host_ref.ForEach([&](auto& self, auto i) {
+                    self(i) = k_host(block_table_host(wb, i[1] / page_block_size),
+                                     i[1] % page_block_size, i[0] / nr, i[2]);
+                });
+            }
+        }
+        else
+#endif
+        {
-            if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
-            else       k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
+            if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
+            else       k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
+        }

+#if CK_TILE_FMHA_FWD_APPENDKV_API
+        // copy Knew to the end of K
+        if(0 < seqlen_knew)
+        {
+            ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
+            if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
+            else       knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
+
+            // optionally apply RoPE to the knew_host_ref
+            auto* real_knew_host_ref = &knew_host_ref;
+            std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
+            if(0 < rotary_dim)
+            {
+                knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
+
+                auto [rotary_cos_slice, rotary_sin_slice] = slice_rotary_cos_sin(
+                    rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
+
+                ck_tile::reference_batched_rotary_position_embedding(knew_host_ref,
+                                                                     rotary_cos_slice,
+                                                                     rotary_sin_slice,
+                                                                     is_rotary_interleaved,
+                                                                     knew_host_ref_ro.value());
+
+                real_knew_host_ref = &knew_host_ref_ro.value();
+            }
+
+            (*real_knew_host_ref).ForEach([&](auto& self, auto i) {
+                k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
+            });
+        }
+#endif

+#if CK_TILE_FMHA_FWD_SPLITKV_API
+        if(0 < page_block_size)
+        {
+            if(is_v_rowmajor)
+            {
+                if(i_perm)
+                {
+                    v_host_ref.ForEach([&](auto& self, auto i) {
+                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
+                                         i[0] / nr, i[2] % page_block_size, i[1]);
+                    });
+                }
+                else
+                {
+                    v_host_ref.ForEach([&](auto& self, auto i) {
+                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
+                                         i[2] % page_block_size, i[0] / nr, i[1]);
+                    });
+                }
+            }
+            else
+            {
+                if(i_perm)
+                {
+                    v_host_ref.ForEach([&](auto& self, auto i) {
+                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
+                                         i[0] / nr, i[1], i[2] % page_block_size);
+                    });
+                }
+                else
+                {
+                    v_host_ref.ForEach([&](auto& self, auto i) {
+                        self(i) = v_host(block_table_host(wb, i[2] / page_block_size),
+                                         i[1], i[0] / nr, i[2] % page_block_size);
+                    });
+                }
+            }
+        }
+        else
+#endif
+        {
             if(is_v_rowmajor)
             {
                 // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
-                if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
+                if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
                 // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
-                else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
+                else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
             }
             else
             {
-                if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); });
-                else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
+                if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
+                else       v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
             }
+        }

+#if CK_TILE_FMHA_FWD_APPENDKV_API
+        // copy Vnew to the end of V
+        if(0 < seqlen_knew)
+        {
+            ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
+            if(is_v_rowmajor)
+            {
+                if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
+                else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
+            }
+            else
+            {
+                if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
+                else       vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
+            }
+            vnew_host_ref.ForEach([&](auto& self, auto i) {
+                v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
+            });
+        }
+#endif
         // clang-format on

         // reference
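Interleaved RoPE rotates adjacent element pairs (x0,x1), (x2,x3), ... of each head vector by a position-dependent angle; the "half-rotated" variant instead pairs element i with element i + rotary_dim/2. A minimal float sketch of the interleaved form — this is the underlying math, not ck_tile's reference implementation:

#include <cmath>
#include <cstdio>
#include <vector>

// Rotate the first rotary_dim elements of one token's head vector, interleaved pairing.
void rope_interleaved(std::vector<float>& x, int pos, int rotary_dim)
{
    for(int i = 0; i < rotary_dim; i += 2)
    {
        // One inverse frequency per pair, as in the usual RoPE definition.
        const float inv_freq = std::pow(10000.f, -static_cast<float>(i) / rotary_dim);
        const float c = std::cos(pos * inv_freq), s = std::sin(pos * inv_freq);
        const float x0 = x[i], x1 = x[i + 1];
        x[i]     = x0 * c - x1 * s;
        x[i + 1] = x0 * s + x1 * c;
    }
}

int main()
{
    std::vector<float> v{1.f, 0.f, 1.f, 0.f};
    rope_interleaved(v, /*pos=*/3, /*rotary_dim=*/4);
    std::printf("%f %f %f %f\n", v[0], v[1], v[2], v[3]);
    return 0;
}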
...
@@ -959,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
             ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
                 {nhead, real_seqlen_q, real_seqlen_k});
             randval_host_ref.ForEach([&](auto& self, auto idx) {
-                self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
+                self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
             });
             ck_tile::reference_batched_dropout(
                 p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
...
@@ -976,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
         // clang-format off
         // permute
-        if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); });
-        else       o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); });
+        if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
+        else       o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
         // clang-format on

         auto [rtol, atol] = get_elimit<DataType>(init_method);
...
@@ -999,7 +1439,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         {
             ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
             lse_host_result.ForEach([&](auto& self, auto idx) {
-                self(idx) = lse_host(b, idx[0], idx[1] + query_offset);
+                self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
             });

             cur_pass = ck_tile::check_err(lse_host_result,
...

example/ck_tile/01_fmha/fmha_fwd.hpp
View file @ d71189ff

...
@@ -5,10 +5,13 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/fmha.hpp"
 #include "ck_tile/ops/epilogue.hpp"
-#include "mask.hpp"
+#include "ck_tile/ops/fmha.hpp"
 #include "bias.hpp"
+#include "mask.hpp"
+#include "rotary.hpp"
 #include <type_traits>

 template <typename DataType>
...
@@ -93,13 +96,86 @@ struct fmha_fwd_args
     const void* v_ptr;
     const void* bias_ptr; // bias or alibi_slope pointer
     void* rand_val_ptr;
+    void* lse_ptr;
+    void* o_ptr;
+
+    const void* seqstart_q_ptr;
+    const void* seqstart_k_ptr;
+    const void*
+        seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
+
+    ck_tile::index_t seqlen_q;
+    ck_tile::index_t seqlen_k;
+    ck_tile::index_t batch;
+    ck_tile::index_t max_seqlen_q;
+    ck_tile::index_t hdim_q;
+    ck_tile::index_t hdim_v;
+    ck_tile::index_t nhead_q;
+    ck_tile::index_t nhead_k;
+
+    float scale_s;
+    float scale_p;
+    float scale_o;
+
+    ck_tile::index_t stride_q;
+    ck_tile::index_t stride_k;
+    ck_tile::index_t stride_v;
+    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
+    ck_tile::index_t stride_randval;
+    ck_tile::index_t stride_o;
+    ck_tile::index_t nhead_stride_q;
+    ck_tile::index_t nhead_stride_k;
+    ck_tile::index_t nhead_stride_v;
+    ck_tile::index_t nhead_stride_bias;
+    ck_tile::index_t nhead_stride_randval;
+    ck_tile::index_t nhead_stride_lse;
+    ck_tile::index_t nhead_stride_o;
+    ck_tile::index_t batch_stride_q;
+    ck_tile::index_t batch_stride_k;
+    ck_tile::index_t batch_stride_v;
+    ck_tile::index_t batch_stride_bias;
+    ck_tile::index_t batch_stride_randval;
+    ck_tile::index_t batch_stride_lse;
+    ck_tile::index_t batch_stride_o;
+
+    ck_tile::index_t window_size_left;
+    ck_tile::index_t window_size_right;
+    ck_tile::index_t mask_type;
+
+    float p_drop;
+    bool s_randval;
+
+    std::tuple<uint64_t, uint64_t> drop_seed_offset;
+};
+
+struct fmha_fwd_splitkv_args
+{
+    const void* q_ptr;
+    const void* k_ptr;
+    const void* v_ptr;
+    const void* bias_ptr; // bias or alibi_slope pointer
     void* lse_acc_ptr;
     void* o_acc_ptr;
     void* lse_ptr;
     void* o_ptr;
+
+    void* block_table_ptr;
+    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
+    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+
+    const void* cache_batch_idx;
+
+    // the real seqlen_q & seqlen_k are decided by following:
+    // batch mode: seqlen_q = kargs.seqlen_q
+    //             seqlen_k = kargs.seqlen_k
+    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
+    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+    // kvcache mode (use same kernel as batch mode):
+    //             seqlen_q = kargs.seqlen_q
+    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;
     ck_tile::index_t batch;
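The group-mode convention documented above derives each sequence's length from an exclusive prefix-sum array. A tiny sketch of how seqstart arrays encode ragged batches (the lengths are made up):

#include <cstdio>

int main()
{
    // Three variable-length sequences of 3, 5 and 2 tokens packed back to back.
    const int seqstart_q[] = {0, 3, 8, 10}; // batch + 1 entries, exclusive prefix sums
    for(int b = 0; b < 3; ++b)
    {
        const int seqlen_q = seqstart_q[b + 1] - seqstart_q[b]; // per-batch length
        std::printf("batch %d: offset %d, seqlen %d\n", b, seqstart_q[b], seqlen_q);
    }
    return 0;
}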
...
@@ -109,21 +185,21 @@ struct fmha_fwd_args
     ck_tile::index_t nhead_q;
     ck_tile::index_t nhead_k;
     ck_tile::index_t num_splits;
     float scale_s;
     float scale_p;
     float scale_o;
     ck_tile::index_t stride_q;
     ck_tile::index_t stride_k;
     ck_tile::index_t stride_v;
     ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
-    ck_tile::index_t stride_randval;
     ck_tile::index_t stride_o_acc;
     ck_tile::index_t stride_o;
     ck_tile::index_t nhead_stride_q;
     ck_tile::index_t nhead_stride_k;
     ck_tile::index_t nhead_stride_v;
     ck_tile::index_t nhead_stride_bias;
-    ck_tile::index_t nhead_stride_randval;
     ck_tile::index_t nhead_stride_lse;
     ck_tile::index_t nhead_stride_lse_acc;
     ck_tile::index_t nhead_stride_o_acc;
...
@@ -132,19 +208,62 @@ struct fmha_fwd_args
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_bias;
    ck_tile::index_t batch_stride_lse;
    ck_tile::index_t batch_stride_lse_acc;
    ck_tile::index_t batch_stride_o_acc;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t split_stride_lse_acc;
    ck_tile::index_t split_stride_o_acc;
    ck_tile::index_t window_size_left;
    ck_tile::index_t window_size_right;
    ck_tile::index_t mask_type;
};

struct fmha_fwd_appendkv_args
{
    void* q_ptr;
    void* k_ptr;
    const void* knew_ptr;
    void* v_ptr;
    const void* vnew_ptr;

    const void* seqlen_k_ptr;
    ck_tile::index_t seqlen_q;
    ck_tile::index_t seqlen_knew;
    ck_tile::index_t batch;
    ck_tile::index_t hdim_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead_q;
    ck_tile::index_t nhead_k;

    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
    ck_tile::index_t rotary_dim;
    bool has_mask;

    void* block_table_ptr;
    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr

    const void* cache_batch_idx;

    ck_tile::index_t stride_q;
    ck_tile::index_t stride_k;
    ck_tile::index_t stride_knew;
    ck_tile::index_t stride_v;
    ck_tile::index_t stride_vnew;
    ck_tile::index_t nhead_stride_q;
    ck_tile::index_t nhead_stride_k;
    ck_tile::index_t nhead_stride_knew;
    ck_tile::index_t nhead_stride_v;
    ck_tile::index_t nhead_stride_vnew;
    ck_tile::index_t batch_stride_q;
    ck_tile::index_t batch_stride_k;
    ck_tile::index_t batch_stride_knew;
    ck_tile::index_t batch_stride_v;
    ck_tile::index_t batch_stride_vnew;
};
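The mode comment inside fmha_fwd_splitkv_args above is the key to how one args struct serves batch, group, and kvcache launches. Below is a minimal host-side sketch of the resolution logic that comment describes; the helper name and the launch_mode enum are illustrative, not part of the header, and only the formulas come from the comment.

#include <cstdint>

enum class launch_mode { batch, group, kvcache }; // hypothetical, for illustration only

// seqstart_k_ptr is the (batch + 1)-entry prefix-sum array named in the comment above
inline std::int32_t effective_seqlen_k(launch_mode mode,
                                       std::int32_t b,
                                       std::int32_t seqlen_k,
                                       const std::int32_t* seqstart_k_ptr)
{
    if(mode == launch_mode::batch)
        return seqlen_k; // batch mode: every batch uses the scalar length
    // group and kvcache modes: per-batch length from consecutive prefix sums
    return seqstart_k_ptr[b + 1] - seqstart_k_ptr[b];
}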
template <typename FmhaKernel>
...
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
 }

 template <typename Kernel>
-auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 {
     assert(args.nhead_q % args.nhead_k == 0);
     auto kargs = [&] {
...
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
             args.k_ptr,
             args.v_ptr,
             args.bias_ptr,
-            args.rand_val_ptr,
             args.lse_acc_ptr,
             args.o_acc_ptr,
             args.batch,
-            args.max_seqlen_q,
             args.seqstart_q_ptr,
             args.seqstart_k_ptr,
             args.seqlen_k_ptr,
...
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
             args.stride_k,
             args.stride_v,
             args.stride_bias,
-            args.stride_randval,
             args.stride_o_acc,
             args.nhead_stride_q,
             args.nhead_stride_k,
             args.nhead_stride_v,
             args.nhead_stride_bias,
-            args.nhead_stride_randval,
             args.nhead_stride_lse_acc,
             args.nhead_stride_o_acc,
+            args.batch_stride_k,
+            args.batch_stride_v,
+            args.batch_stride_lse_acc,
             args.batch_stride_o_acc,
             args.split_stride_lse_acc,
             args.split_stride_o_acc,
             args.window_size_left,
             args.window_size_right,
-            args.mask_type,
-            args.p_drop,
-            args.s_randval,
-            args.drop_seed_offset);
+            args.mask_type);
     }
     else
     { // create batch mode kernel arguments
...
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
             args.k_ptr,
             args.v_ptr,
             args.bias_ptr,
-            args.rand_val_ptr,
             args.lse_acc_ptr,
             args.o_acc_ptr,
             args.batch,
-            args.max_seqlen_q,
             args.seqlen_q,
             args.seqlen_k,
+            args.seqlen_k_ptr,
             args.hdim_q,
             args.hdim_v,
             args.nhead_q,
             args.nhead_q / args.nhead_k,
             args.num_splits,
+            args.block_table_ptr,
+            args.batch_stride_block_table,
+            args.page_block_size,
+            args.cache_batch_idx,
             args.scale_s,
             args.scale_p,
             args.stride_q,
             args.stride_k,
             args.stride_v,
             args.stride_bias,
-            args.stride_randval,
             args.stride_o_acc,
             args.nhead_stride_q,
             args.nhead_stride_k,
             args.nhead_stride_v,
             args.nhead_stride_bias,
-            args.nhead_stride_randval,
             args.nhead_stride_lse_acc,
             args.nhead_stride_o_acc,
             args.batch_stride_q,
             args.batch_stride_k,
             args.batch_stride_v,
             args.batch_stride_bias,
-            args.batch_stride_randval,
             args.batch_stride_lse_acc,
             args.batch_stride_o_acc,
             args.split_stride_lse_acc,
             args.split_stride_o_acc,
             args.window_size_left,
             args.window_size_right,
-            args.mask_type,
-            args.p_drop,
-            args.s_randval,
-            args.drop_seed_offset);
+            args.mask_type);
     }
 }();
...
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 }

 template <typename Kernel>
-auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 {
     assert(args.nhead_q % args.nhead_k == 0);
     auto kargs = [&] {
...
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
     return ck_tile::make_tuple(kargs, grids);
 }

+template <typename Kernel>
+auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
+{
+    assert(args.nhead_q % args.nhead_k == 0);
+    auto kargs = Kernel::MakeKargs(args.q_ptr,
+                                   args.k_ptr,
+                                   args.knew_ptr,
+                                   args.v_ptr,
+                                   args.vnew_ptr,
+                                   args.seqlen_q,
+                                   args.seqlen_k_ptr,
+                                   args.seqlen_knew,
+                                   args.hdim_q,
+                                   args.hdim_v,
+                                   args.nhead_q,
+                                   args.nhead_q / args.nhead_k,
+                                   args.rotary_cos_ptr,
+                                   args.rotary_sin_ptr,
+                                   args.rotary_dim,
+                                   args.has_mask,
+                                   args.block_table_ptr,
+                                   args.batch_stride_block_table,
+                                   args.page_block_size,
+                                   args.cache_batch_idx,
+                                   args.stride_q,
+                                   args.stride_k,
+                                   args.stride_knew,
+                                   args.stride_v,
+                                   args.stride_vnew,
+                                   args.nhead_stride_q,
+                                   args.nhead_stride_k,
+                                   args.nhead_stride_knew,
+                                   args.nhead_stride_v,
+                                   args.nhead_stride_vnew,
+                                   args.batch_stride_q,
+                                   args.batch_stride_k,
+                                   args.batch_stride_knew,
+                                   args.batch_stride_v,
+                                   args.batch_stride_vnew);
+
+    dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
+
+    return ck_tile::make_tuple(kargs, grids);
+}
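For intuition, Kernel::GridSize(batch, nhead_q, seqlen_q, seqlen_knew) lets the append-kv kernel tile over both the queries and the newly appended keys/values. The sketch below shows the usual round-up tile arithmetic involved; the tile sizes are hypothetical stand-ins for the kernel's kTileSizeS/kTileSizeSk traits, so treat this as a rough model rather than the kernel's exact grid formula.

#include <algorithm>

// hypothetical tile sizes standing in for the kernel's kTileSizeS / kTileSizeSk traits
inline unsigned appendkv_tiles(unsigned seqlen_q, unsigned seqlen_knew,
                               unsigned tile_s = 64, unsigned tile_sk = 64)
{
    const unsigned q_tiles  = (seqlen_q + tile_s - 1) / tile_s;      // tiles over queries
    const unsigned kv_tiles = (seqlen_knew + tile_sk - 1) / tile_sk; // tiles over appended K/V
    // work is replicated over batch * nhead_q, scaled by these per-sequence tile counts
    return std::max(q_tiles, kv_tiles);
}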
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
          typename DataType_,
...
@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
 template <typename Traits_>
 float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);

+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          bool kIsGroupMode_,
+          ck_tile::index_t kM0_,
+          ck_tile::index_t kN0_,
+          ck_tile::index_t kK0_,
+          ck_tile::index_t kN1_,
+          ck_tile::index_t kK1_,
+          ck_tile::index_t kK0BlockLength_,
+          bool kIsVLayoutRowMajor_,
+          ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
+          typename FmhaMask_,
+          ck_tile::BlockAttentionBiasEnum BiasEnum_,
+          bool kStoreLse_,
+          bool kDoFp8StaticQuant_,
+          bool kIsPagedKV_,
+          bool kPadS_,
+          bool kPadSK_,
+          bool kPadD_,
+          bool kPadDv_>
+struct fmha_fwd_splitkv_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr bool kIsGroupMode = kIsGroupMode_;
+    static constexpr ck_tile::index_t kM0 = kM0_;
+    static constexpr ck_tile::index_t kN0 = kN0_;
+    static constexpr ck_tile::index_t kK0 = kK0_;
+    static constexpr ck_tile::index_t kN1 = kN1_;
+    static constexpr ck_tile::index_t kK1 = kK1_;
+    static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
+    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
+    static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
+    using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
+    static constexpr auto BiasEnum = BiasEnum_;
+    static constexpr bool kStoreLse = kStoreLse_;
+    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kPadS = kPadS_;
+    static constexpr bool kPadSK = kPadSK_;
+    static constexpr bool kPadD = kPadD_;
+    static constexpr bool kPadDv = kPadDv_;
+    static constexpr bool kIsPagedKV = kIsPagedKV_;
+};

 template <typename Traits_>
-void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

 template <typename Traits_>
 std::string fmha_fwd_splitkv_get_name_();
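These *_traits_ structs exist so generated translation units can be matched against one concrete tile configuration. As a purely illustrative sketch of how the template parameters line up, a hypothetical specialization could read as follows; every tile number, the pipeline/bias enum values, and the mask type below are invented for this example and are not taken from this commit's generated kernels.

// every value below is a placeholder chosen for illustration
using example_splitkv_trait =
    fmha_fwd_splitkv_traits_<128,                 // HDim_
                             ck_tile::fp16_t,     // DataType_
                             false,               // kIsGroupMode_
                             64, 64, 32, 64, 32,  // kM0_, kN0_, kK0_, kN1_, kK1_
                             128,                 // kK0BlockLength_
                             true,                // kIsVLayoutRowMajor_
                             ck_tile::BlockFmhaPipelineEnum::QRKSVS,
                             ck_tile::SimplifiedGenericAttentionMask<false>,
                             ck_tile::BlockAttentionBiasEnum::NO_BIAS,
                             false,               // kStoreLse_
                             false,               // kDoFp8StaticQuant_
                             false,               // kIsPagedKV_
                             true, true, true, true>; // kPadS_, kPadSK_, kPadD_, kPadDv_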
...
@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
 };

 template <typename Traits_>
-void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args);
+void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);

 template <typename Traits_>
 std::string fmha_fwd_splitkv_combine_get_name_();

+// this is used to pattern-match internal kernel implementation, not to instantiate kernel
+template <ck_tile::index_t HDim_,
+          typename DataType_,
+          ck_tile::index_t kTileSizeS_,
+          ck_tile::index_t kTileSizeSk_,
+          ck_tile::index_t kTileSizeD_,
+          ck_tile::index_t kTileSizeDv_,
+          bool kIsVLayoutRowMajor_,
+          bool kPadS_,
+          bool kPadSk_,
+          bool kPadD_,
+          bool kPadDv_,
+          ck_tile::RotaryEmbeddingEnum RotaryEnum_,
+          bool kIsPagedKV_>
+struct fmha_fwd_appendkv_traits_
+{
+    static constexpr ck_tile::index_t HDim = HDim_;
+    using DataType = ck_tile::remove_cvref_t<DataType_>;
+    static constexpr ck_tile::index_t kTileSizeS  = kTileSizeS_;
+    static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
+    static constexpr ck_tile::index_t kTileSizeD  = kTileSizeD_;
+    static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
+    static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
+    static constexpr bool kPadS  = kPadS_;
+    static constexpr bool kPadSk = kPadSk_;
+    static constexpr bool kPadD  = kPadD_;
+    static constexpr bool kPadDv = kPadDv_;
+    static constexpr auto RotaryEnum = RotaryEnum_;
+    static constexpr bool kIsPagedKV = kIsPagedKV_;
+};
+
+template <typename Traits_>
+float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
 // This is the public API, will be generated by script
 struct fmha_fwd_traits
 {
...
@@ -508,4 +743,32 @@ struct fmha_fwd_traits
     // TODO: padding check is inside this api
 };
 float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
-float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
+
+struct fmha_fwd_splitkv_traits
+{
+    int hdim_q;
+    int hdim_v;
+    std::string data_type;
+    bool is_group_mode;
+    bool is_v_rowmajor;
+    mask_enum mask_type;
+    bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
+    bool has_lse;
+    bool do_fp8_static_quant;
+    // TODO: padding check is inside this api
+};
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, fmha_fwd_splitkv_args, const ck_tile::stream_config&);
+
+struct fmha_fwd_appendkv_traits
+{
+    int hdim_q;
+    int hdim_v;
+    std::string data_type;
+    bool is_v_rowmajor;
+    rope_enum rope_type;
+};
+float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&);
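A hedged caller-side sketch of the split-kv entry point declared above: the traits/args/stream_config calling convention is exactly what the declarations expose, but the field values are placeholders, only the most important fields are shown, and a real caller (see fmha_fwd.cpp) must fill in every length and stride before launching.

float run_splitkv_example(void* q, void* k, void* v, void* lse_acc, void* o_acc, void* lse, void* o)
{
    fmha_fwd_splitkv_traits traits{};
    traits.hdim_q        = 128;
    traits.hdim_v        = 128;
    traits.data_type     = "fp16";
    traits.is_group_mode = false;
    traits.is_v_rowmajor = true;
    traits.has_lse       = false;

    fmha_fwd_splitkv_args args{};
    args.q_ptr       = q;
    args.k_ptr       = k;
    args.v_ptr       = v;
    args.lse_acc_ptr = lse_acc; // per-split log-sum-exp accumulator
    args.o_acc_ptr   = o_acc;   // per-split output accumulator
    args.lse_ptr     = lse;
    args.o_ptr       = o;
    args.num_splits  = 2;       // the combine kernel reduces the per-split partials
    // ... seqlen_q/seqlen_k, nhead_q/nhead_k, scale_s, and all stride fields
    //     must also be set before launching

    return fmha_fwd_splitkv(traits, args, ck_tile::stream_config{});
}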
example/ck_tile/01_fmha/generate.py
@@ -5,25 +5,30 @@
 import argparse
 from enum import IntEnum
 from pathlib import Path
+import pkgutil
+import sys
 from typing import List, Optional

+import codegen.ops
 from codegen.cmake_config import *
-from codegen.ops import (fmha_fwd, fmha_fwd_splitkv, fmha_bwd)

 class HandlerId(IntEnum):
     LIST_BLOBS  = 0
     WRITE_BLOBS = 1

-handlers = {
-    'fwd'         : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
-    'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
-    'bwd'         : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
-}
+# inspect all modules under 'codegen.ops' and register API handlers
+ops = []
+for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
+    full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
+    if full_module_name not in sys.modules:
+        ops.append(importer.find_spec(module_name).loader.load_module(module_name))
+unwanted_prefix = 'fmha_'
+handlers = dict([(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix)
+                  else op.__name__,
+                  (op.list_blobs, op.write_blobs)) for op in ops])
+assert 0 < len(handlers)

 def write_blobs(output_dir: Optional[str], api_list: List[str], kernel_filter: Optional[str], receipt, mask_impl) -> None:
     if output_dir is None:
...
example/ck_tile/01_fmha/rotary.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <algorithm> // for std::generate/std::transform used below
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>

// keep sync with RotaryEmbeddingEnum
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
                        ck_tile::index_t rotary_dim,
                        std::optional<unsigned> seed = std::nullopt)
{
    // return dummy tensors if we won't apply RoPE at all
    if(rotary_dim <= 0)
    {
        ck_tile::HostTensor<DataType> dummy({1, 1});
        return std::make_tuple(dummy, dummy);
    }

    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    std::uniform_real_distribution<float> generator(0.0f, 1.0f);

    const ck_tile::index_t num_rows = seqlen * 2;
    const ck_tile::index_t num_cols = rotary_dim / 2;

    using std::begin, std::end;

    ck_tile::HostTensor<float> angle({num_rows, num_cols});
    std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });

    ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::cos(origin_value));
    });

    ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
    std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
        return ck_tile::type_convert<DataType>(std::sin(origin_value));
    });

    return std::make_tuple(cos, sin);
}

template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
                     const ck_tile::HostTensor<DataType>& sin,
                     ck_tile::index_t seqlen_offset,
                     ck_tile::index_t seqlen)
{
    assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
    assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
    assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));

    const ck_tile::index_t num_rows = seqlen;
    const ck_tile::index_t num_cols = cos.get_length(1);

    ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
    cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });

    ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
    sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });

    return std::make_tuple(cos_pt, sin_pt);
}
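A short host-side usage sketch of the two helpers in this new file; the sizes are arbitrary example values.

void rotary_usage_example()
{
    // cos/sin tables sized (2 * seqlen) x (rotary_dim / 2), filled from random angles
    auto [cos, sin] = generate_rotary_cos_sin<float>(/*seqlen=*/128, /*rotary_dim=*/64, /*seed=*/42u);

    // take the 16-row window starting at cache position 100, e.g. for tokens
    // being appended to the KV-cache in the fmha_fwd_appendkv example
    auto [cos_pt, sin_pt] = slice_rotary_cos_sin(cos, sin, /*seqlen_offset=*/100, /*seqlen=*/16);
}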
example/ck_tile/01_fmha/script/benchmark_bwd.sh
 #!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
 VALID=0

 for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/benchmark_fwd.sh
 #!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
 VALID=0

 for prec in "fp16" "bf16" ; do
...
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
 #!/bin/sh
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_bwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
 KNAME=1

 export CK_WARMUP=0
...
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
-#!/bin/sh
+#!/bin/bash
-# TODO: run this script from CK root
-BUILD=build
-EXE=$BUILD/bin/tile_example_fmha_fwd
+# TODO: run this script from CK root or build directory
+EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
 KNAME=1

 export CK_WARMUP=0
 export CK_REPEAT=1
...
@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4

TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
    case "${opt}" in
        s ) TEST_SPLITKV=1 ;;
        a ) TEST_APPENDKV=1 ;;
        * ) ;;
    esac
done

run_fp16_bf16_tests() {
    local NUM_SPLITS=(1)
    local PAGE_BLOCK_SIZE=(0)
    local CACHE_BATCH_IDX=(0)

    if [ $TEST_SPLITKV -eq 1 ] ; then
        NUM_SPLITS+=(2 3)
        PAGE_BLOCK_SIZE+=(128)
        CACHE_BATCH_IDX+=(1)
    fi

    for prec in "fp16" "bf16" ; do
    for mode in 1 0 ; do
    for perm in 0 1 ; do
    for vlayout in "r" "c" ; do
    for hdim in 32 64 128 256 ; do
    for lse in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for p_drop in 0.0 0.2 ; do
    for num_splits in "${NUM_SPLITS[@]}" ; do
    for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
    for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
    # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done ; done ; done
    done ;
}

run_fp8_tests() {
    for perm in 0 1 ; do
    for bias in "n" "e" "a" ; do
    for b in 1 2 ; do
    for hdim in 64 128 256 ; do
    $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
    done ; done ; done ; done
}

run_fp16_appendkv_tests() {
    for s in $(seq 63 1 65) ; do
    for s_k in 65 129 ; do
    for s_knew in 0 64 $s_k ; do
    for hdim in 32 64 128 256 ; do
    for ri in 0 1 ; do
    for rdim in 0 16 32 $hdim ; do
    for page_block_size in 0 128 ; do
    for cache_batch_idx in 0 1 ; do
    $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
    done ; done ; done ; done ; done
    done ; done ; done
}

set -x

run_fp16_bf16_tests
run_fp8_tests

if [ $TEST_APPENDKV -eq 1 ] ; then
    run_fp16_appendkv_tests
fi

set +x
example/ck_tile/01_fmha/utils.hpp
@@ -3,15 +3,17 @@
 #pragma once

+#include <algorithm>
 #include <cstdint>
 #include <cstdlib>
+#include <functional>
 #include <optional>
 #include <ostream>
+#include <sstream>
+#include <string>
 #include <tuple>
 #include <utility>
 #include <vector>
-#include <functional>
-#include <string>

 #include "ck_tile/core/container/span.hpp"
...
@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
 std::vector<int32_t> generate_seqlens(mode_enum mode,
                                       unsigned count,
                                       int32_t seqlen_avg,
+                                      int32_t seqlen_min = -1, // if not negative, clamp min
                                       int32_t seqlen_max = -1, // if not negative, clamp max
                                       std::optional<unsigned> seed = std::nullopt)
 {
     assert(0 < count);

-    std::vector<int32_t> seqlens(
-        count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg);
+    seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
+    seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
+    assert(seqlen_min <= seqlen_max);
+
+    std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));

     if(mode == mode_enum::group && 1 < count)
     {
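The net effect of the replacement above is easiest to see in isolation. Here is a standalone restatement of the new initialization (mirroring the hunk, with worked values in the trailing comment):

#include <algorithm>
#include <cstdint>
#include <limits>

std::int32_t initial_seqlen(std::int32_t seqlen_avg, std::int32_t seqlen_min, std::int32_t seqlen_max)
{
    seqlen_min = (0 < seqlen_min ? seqlen_min : 1); // -1 means "no lower clamp", default 1
    seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<std::int32_t>::max());
    return std::clamp(seqlen_avg, seqlen_min, seqlen_max);
}
// initial_seqlen(100, -1, -1) == 100, initial_seqlen(100, 128, -1) == 128,
// initial_seqlen(100, -1, 64) == 64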
...
@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
     for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
     {
         const size_type to_decrease = next_idx();
-        // make sure each elements of seqlens is always greater than 0
-        if(seqlens[to_decrease] == 1)
+        // make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
+        if(seqlens[to_decrease] == seqlen_min)
         {
             continue;
         }

         const size_type to_increase = (to_decrease + next_step()) % count;
-        if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max)
+        if(seqlens[to_increase] >= seqlen_max)
         {
             continue;
         }
...
@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
 std::vector<int32_t> generate_seqstarts(mode_enum mode,
                                         unsigned count,
                                         int32_t seqlen_avg,
+                                        int32_t seqlen_min = -1,
                                         int32_t seqlen_max = -1,
                                         std::optional<unsigned> seed = std::nullopt)
 {
-    return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed));
+    return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
 }
+
+// return random integer generated uniformly in range [low, high]
+template <typename Int = int>
+auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
+    -> std::enable_if_t<std::is_integral_v<Int>, Int>
+{
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::uniform_int_distribution<Int> dist(low, high);
+    return dist(engine);
+}
+
+// return random integers generated uniformly in range [low, high]
+template <typename Int, typename ForwardIterator>
+auto randints(ForwardIterator first,
+              ForwardIterator last,
+              Int low,
+              Int high,
+              std::optional<unsigned> seed = std::nullopt)
+    -> std::enable_if_t<std::is_integral_v<Int>>
+{
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::uniform_int_distribution<Int> dist(low, high);
+    std::generate(first, last, [&] { return dist(engine); });
+}
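A usage sketch for the two helpers above: passing a seed makes the draw reproducible, while omitting it falls back to std::random_device.

#include <vector>

void random_helpers_example()
{
    const int one = randint(1, 64, /*seed=*/123u); // one uniform draw in [1, 64]

    std::vector<int> seqlen_ks(8);
    randints(seqlen_ks.begin(), seqlen_ks.end(), 1, 64, /*seed=*/123u); // fill range in [1, 64]
    (void)one;
}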
/*
...
@@ -112,6 +144,8 @@ decode_seqlen(mode_enum mode,
               std::string q_val,
               std::string k_val,
               std::string k_pad_val,
+              ck_tile::index_t seqlen_k_min = 0,
+              bool use_kvcache = false,
               std::optional<unsigned> seed = std::nullopt)
 {
 #define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...
@@ -119,9 +153,36 @@ decode_seqlen(mode_enum mode,
     {
         ck_tile::index_t q = _S2I_(q_val);
         ck_tile::index_t k = _S2I_(k_val);

         auto s_q = std::vector<ck_tile::index_t>(batch, q);
-        auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k);
+        auto s_k = [&] {
+            const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
+            std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
+            if(1 < batch && use_kvcache)
+            {
+                // to keep the original s_k value, we always use seqlen_k_max in first batch
+                randints(std::next(seqlen_ks.begin()),
+                         seqlen_ks.end(),
+                         seqlen_k_min,
+                         seqlen_k_max,
+                         seed);
+                return seqlen_ks;
+            }
+            return seqlen_ks;
+        }();
         auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding

+        // s_k should be greater than or equal to seqlen_k_min if provided
+        if(s_k.back() < seqlen_k_min)
+        {
+            std::ostringstream msg;
+            msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+                << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+            throw std::runtime_error(msg.str());
+        }

         return std::make_tuple(s_q, s_k, s_kpad);
     }
     else
...
@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
             s_q.push_back(q);
             s_k.push_back(k < 0 ? q : k);
             s_kpad.push_back(kp);

+            // s_k should be greater than or equal to seqlen_k_min
+            if(s_k.back() < seqlen_k_min)
+            {
+                std::ostringstream msg;
+                msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
+                    << ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
+                throw std::runtime_error(msg.str());
+            }

             idx++;
             if(found_q == std::string::npos || idx >= batch)
             {
...
@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
     }

     if(idx < batch)
     {
-        auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed);
-        auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed);
+        auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
+        auto rem_k =
+            generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);

         s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
         s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
...
@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
         r = std::atoi(v);
     return r;
 }
+
+template <typename RandomAccessIterator, typename Int>
+std::enable_if_t<std::is_integral_v<Int>>
+iota_shuffle(RandomAccessIterator first,
+             RandomAccessIterator last,
+             Int value,
+             std::optional<unsigned> seed = std::nullopt)
+{
+    std::iota(first, last, value);
+
+    std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
+    std::shuffle(first, last, engine);
+}
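iota_shuffle simply combines std::iota and std::shuffle. A typical use (only suggestive, not mandated by the header) is building a random batch-index permutation such as the one the examples feed through cache_batch_idx:

#include <cstdint>
#include <vector>

std::vector<std::int32_t> make_batch_permutation(std::size_t batch)
{
    std::vector<std::int32_t> perm(batch);
    iota_shuffle(perm.begin(), perm.end(), std::int32_t{0}, /*seed=*/7u); // 0..batch-1, shuffled
    return perm;
}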
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...
@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
             return false;

         if constexpr(!((NDimSpatial == 1 &&
-                        (is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) ||
+                        (is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
                        (NDimSpatial == 2 &&
-                        (is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) ||
+                        (is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
                        (NDimSpatial == 3 &&
-                        (is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                         is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))))
+                        (is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                         is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp
...
@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
         }

         if constexpr(NDimSpatial == 1)
         {
-            if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+            if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
             {
                 return false;
             }
         }
         else if constexpr(NDimSpatial == 2)
         {
-            if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
             {
                 return false;
             }
         }
         else if constexpr(NDimSpatial == 3)
         {
-            if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
             {
                 return false;
             }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...
@@ -22,6 +22,7 @@
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
...
@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
           BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
           index_t NumGroupsToMerge = 1,
           typename ComputeTypeA   = InDataType,
-          typename ComputeTypeB   = ComputeTypeA>
+          typename ComputeTypeB   = ComputeTypeA,
+          index_t TransposeTransferSrcScalarPerVector = 1,
+          index_t TransposeTransferDstScalarPerVector = 1>
 struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     : public DeviceGroupedConvBwdWeight<NDimSpatial,
                                         InLayout,
...
@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     using BDataType = InDataType;
     using EDataType = WeiDataType;

+    // If NGCHW then ADataType must be equal to BDataType
+    static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                    is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
+                  is_same_v<ADataType, BDataType>);

     using AElementwiseOperation   = OutElementwiseOperation;
     using BElementwiseOperation   = InElementwiseOperation;
     using CDEElementwiseOperation = WeiElementwiseOperation;
...
@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                                            batch)[I2];
     }

+    static constexpr index_t ClusterLengthMPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
+    static constexpr index_t ClusterLengthNPerBlock =
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    static auto
+    MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                           std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Hi = g_n_c_wis_lengths[3];
+        const index_t& Wi = g_n_c_wis_lengths[4];
+
+        const index_t& GStride  = g_n_c_wis_strides[0];
+        const index_t& NStride  = g_n_c_wis_strides[1];
+        const index_t& CStride  = g_n_c_wis_strides[2];
+        const index_t& HiStride = g_n_c_wis_strides[3];
+        const index_t& WiStride = g_n_c_wis_strides[4];
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    static auto
+    MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                            std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Hi = g_n_c_wis_lengths[3];
+        const index_t& Wi = g_n_c_wis_lengths[4];
+
+        const index_t& NStride = g_n_c_wis_strides[1];
+        const index_t HiStride = Wi * G * C;
+        const index_t WiStride = G * C;
+        const index_t GStride  = C;
+        const index_t CStride  = 1;
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
+    static auto
+    MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                           std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Di = g_n_c_wis_lengths[3];
+        const index_t& Hi = g_n_c_wis_lengths[4];
+        const index_t& Wi = g_n_c_wis_lengths[5];
+
+        const index_t& GStride  = g_n_c_wis_strides[0];
+        const index_t& NStride  = g_n_c_wis_strides[1];
+        const index_t& CStride  = g_n_c_wis_strides[2];
+        const index_t& DiStride = g_n_c_wis_strides[3];
+        const index_t& HiStride = g_n_c_wis_strides[4];
+        const index_t& WiStride = g_n_c_wis_strides[5];
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Di, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Di, Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
+    static auto
+    MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
+                            std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
+    {
+        const index_t& G  = g_n_c_wis_lengths[0];
+        const index_t& N  = g_n_c_wis_lengths[1];
+        const index_t& C  = g_n_c_wis_lengths[2];
+        const index_t& Di = g_n_c_wis_lengths[3];
+        const index_t& Hi = g_n_c_wis_lengths[4];
+        const index_t& Wi = g_n_c_wis_lengths[5];
+
+        const index_t& NStride = g_n_c_wis_strides[1];
+        const index_t DiStride = Hi * Wi * G * C;
+        const index_t HiStride = Wi * G * C;
+        const index_t WiStride = G * C;
+        const index_t GStride  = C;
+        const index_t CStride  = 1;
+
+        const auto desc = make_naive_tensor_descriptor(
+            make_tuple(N, G, C, Di, Hi, Wi),
+            make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
+        const auto merged_desc = transform_tensor_descriptor(
+            desc,
+            make_tuple(make_merge_transform(make_tuple(N, G, C)),
+                       make_merge_transform(make_tuple(Di, Hi, Wi))),
+            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+        return PadTensorDescriptor(
+            merged_desc,
+            make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
+            Sequence<true, true>{});
+    }
+
+    using InputTransposeDescType =
+        remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
+    using OutputTransposeDescType =
+        remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;

     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());

     using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
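Both descriptor builders above merge (N, G, C) and the spatial lengths into a 2D view and then pad each merged axis up to a multiple of the per-cluster tile. The padding arithmetic is the usual round-up, sketched here with made-up sizes:

#include <cstdint>

constexpr std::int64_t pad_up(std::int64_t len, std::int64_t granule)
{
    return ((len + granule - 1) / granule) * granule; // smallest multiple of granule >= len
}

// e.g. a merged axis of length 30 padded for a per-cluster tile of 8:
static_assert(pad_up(30, 8) == 32, "30 rounds up to 32");
static_assert(pad_up(32, 8) == 32, "aligned lengths are unchanged");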
...
@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                               ComputeTypeA,
                                               ComputeTypeB>;

-    static constexpr index_t ClusterLengthMPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
-    static constexpr index_t ClusterLengthNPerBlock =
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
-
     using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

-    using GridwiseElementwise =
+    using GridwiseElementwiseCast =
         GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
                             Tuple<CElementwiseGridDesc_M_N>,
                             Tuple<const AccDataType*>,
...
@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                             I1,
                             I1>;

+    using GridwiseElementwiseTranspose =
+        GridwiseElementwise<Tuple<InputTransposeDescType>,
+                            Tuple<OutputTransposeDescType>,
+                            Tuple<const ADataType*>,
+                            Tuple<ADataType*>,
+                            Block2TileMapElementwise,
+                            element_wise::PassThrough,
+                            BlockSize,
+                            MPerBlock,
+                            NPerBlock,
+                            MPerBlock / ClusterLengthMPerBlock,
+                            NPerBlock / ClusterLengthNPerBlock,
+                            Sequence<1, 0>,
+                            Sequence<TransposeTransferSrcScalarPerVector>,
+                            Sequence<TransposeTransferDstScalarPerVector>,
+                            I1,
+                            I0>;

     // Argument
     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
         decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...
@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                             end(a_g_n_k_wos_lengths),
                             begin(output_spatial_lengths_));

+            std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
+                b_g_n_c_wis_strides;
+            std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
+                a_g_n_k_wos_strides;
+            // NGKHW - transpose needed
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
+                b_g_n_c_wis_strides_transposed[I2] = I1;
+                a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
+                a_g_n_k_wos_strides_transposed[I2] = I1;
+                if constexpr(NDimSpatial == 2)
+                {
+                    b_g_n_c_wis_strides_transposed[I3] =
+                        input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
+                    a_g_n_k_wos_strides_transposed[I3] =
+                        output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
+                }
+                else if constexpr(NDimSpatial == 3)
+                {
+                    b_g_n_c_wis_strides_transposed[I3] =
+                        input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I4] =
+                        input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
+                    b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
+                    a_g_n_k_wos_strides_transposed[I3] =
+                        output_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I4] =
+                        input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
+                    a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
+                }
+            }
+
             const auto descs =
                 conv_to_gemm_transformer_v2
                     .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
...
@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                         input_spatial_lengths_,
                         filter_spatial_lengths_,
                         output_spatial_lengths_,
-                        b_g_n_c_wis_strides,
+                        b_g_n_c_wis_strides_transposed,
                         e_g_k_c_xs_strides,
-                        a_g_n_k_wos_strides,
+                        a_g_n_k_wos_strides_transposed,
                         conv_filter_strides,
                         conv_filter_dilations,
                         input_left_pads,
...
@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);

             // A/B/C Batch Stride
-            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
-            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
+            compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
+            compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
             compute_ptr_offset_of_batch_.BatchStrideC_ =
                 Conv_K_ * Conv_C_ *
                 std::accumulate(begin(filter_spatial_lengths_),
...
@@ -553,13 +750,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 ce_grid_desc_m_n_,
                 GridwiseGemm::CalculateMBlock(GemmM),
                 GridwiseGemm::CalculateNBlock(GemmN));

+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                a_in_transpose_desc_ =
+                    MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                a_out_transpose_desc_ =
+                    MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
+                b_in_transpose_desc_ =
+                    MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+                b_out_transpose_desc_ =
+                    MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
+
+                elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
+                    a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
+                elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
+                    b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
+            }
         }

-        std::size_t GetWorkspaceSizeBytes() const
+        std::size_t GetWorkspaceATensorSizeBytes() const
+        {
+            return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
+        }
+
+        std::size_t GetWorkspaceBTensorSizeBytes() const
+        {
+            return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
+        }
+
+        std::size_t GetWorkspaceETensorSizeBytes() const
         {
             return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
         }

+        std::size_t GetWorkspaceSizeBytes() const
+        {
+            // Transpose require workspace for A and B
+            if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                         is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+            {
+                return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
+                       GetWorkspaceETensorSizeBytes();
+            }
+            else
+            {
+                return GetWorkspaceETensorSizeBytes();
+            }
+        }

         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         EDataType* p_e_grid_;
...
@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
        Block2TileMapElementwise elementwise_block_2_ctile_map_;
+       Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
+           elementwise_block_2_ctile_map_transpose_b_;
+       InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
+       OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
...
@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
+           const ADataType* p_a_grid = arg.p_a_grid_;
+           const BDataType* p_b_grid = arg.p_b_grid_;
+           if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                        is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+           {
+               p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
+                          arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
+               p_b_grid = type_convert<const BDataType*>(arg.p_workspace_) +
+                          (arg.GetWorkspaceETensorSizeBytes() +
+                           arg.GetWorkspaceATensorSizeBytes()) /
+                              sizeof(BDataType);
+           }
            // nullptr for output, will be set after workspace set
-           typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_,
-                                                    arg.p_b_grid_,
-                                                    p_c_grid,
-                                                    GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
+           typename GridwiseGemm::Argument gemm_arg{
+               p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};

            index_t gdx, gdy, gdz;
            std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
...
@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;

            const auto clear_workspace = [&]() {
                hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
                                               0,
-                                              arg.GetWorkspaceSizeBytes(),
+                                              arg.GetWorkspaceETensorSizeBytes(),
                                               stream_config.stream_id_));
            };

            const auto Run = [&](const auto& kernel) {
...
@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
+           float avg_time = 0.f;
            auto launch_elementwise_kernel = [&]() {
                const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
                const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
...
@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                std::array<index_t, I1> in_out_batch_strides = {
                    static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};

-               const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
+               const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
                                                               ck::Tuple<CElementwiseGridDesc_M_N>,
                                                               ck::Tuple<CElementwiseGridDesc_M_N>,
                                                               ck::Tuple<const AccDataType*>,
...
@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                    in_out_batch_strides);
            };

-           float avg_time = RunGemmV3(arg, stream_config);
+           if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                        is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+           {
+               const index_t grid_size_a =
+                   arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
+                       arg.a_in_transpose_desc_);
+               const index_t grid_size_b =
+                   arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
+                       arg.b_in_transpose_desc_);
+
+               ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
+                                         arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
+               BDataType* p_b_out_grid = type_convert<BDataType*>(arg.p_workspace_) +
+                                         (arg.GetWorkspaceETensorSizeBytes() +
+                                          arg.GetWorkspaceATensorSizeBytes()) /
+                                             sizeof(BDataType);
+
+               auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
+                                                               ck::Tuple<InputTransposeDescType>,
+                                                               ck::Tuple<InputTransposeDescType>,
+                                                               ck::Tuple<OutputTransposeDescType>,
+                                                               ck::Tuple<OutputTransposeDescType>,
+                                                               ck::Tuple<const ADataType*>,
+                                                               ck::Tuple<BDataType*>,
+                                                               Block2TileMapElementwise,
+                                                               Block2TileMapElementwise,
+                                                               element_wise::PassThrough>;
+
+               avg_time += launch_and_time_kernel(stream_config,
+                                                  kernel_transpose,
+                                                  dim3(grid_size_a + grid_size_b),
+                                                  dim3(BlockSize),
+                                                  0,
+                                                  make_tuple(arg.a_in_transpose_desc_),
+                                                  make_tuple(arg.b_in_transpose_desc_),
+                                                  make_tuple(arg.a_out_transpose_desc_),
+                                                  make_tuple(arg.b_out_transpose_desc_),
+                                                  make_tuple(arg.p_a_grid_),
+                                                  make_tuple(arg.p_b_grid_),
+                                                  make_tuple(p_a_out_grid),
+                                                  make_tuple(p_b_out_grid),
+                                                  arg.elementwise_block_2_ctile_map_transpose_a_,
+                                                  arg.elementwise_block_2_ctile_map_transpose_b_,
+                                                  element_wise::PassThrough{},
+                                                  grid_size_a);
+           }
+           avg_time += RunGemmV3(arg, stream_config);
            avg_time += launch_elementwise_kernel();
            return avg_time;
        }
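Note: Run above now sequences up to three timed launches. A reduced host-side sketch of that control flow, with stub launchers standing in for the real launch_and_time_kernel calls (all names here are illustrative assumptions, not CK API):

    // Stubs; each returns elapsed milliseconds.
    static float launch_transpose_ms() { return 0.1f; }
    static float launch_split_k_gemm_ms() { return 1.0f; }
    static float launch_cast_to_weights_ms() { return 0.2f; }

    float run_two_stage(bool input_needs_transpose)
    {
        float avg_time = 0.f;
        if(input_needs_transpose)                // NGCHW/NGCDHW layouts only
            avg_time += launch_transpose_ms();   // pack A and B into workspace
        avg_time += launch_split_k_gemm_ms();    // accumulate into AccDataType buffer
        avg_time += launch_cast_to_weights_ms(); // elementwise cast to final weights
        return avg_time;
    }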
...
@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
        {
            return false;
        }
-       if constexpr(NDimSpatial == 1)
-       {
-           if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
-           {
-               return false;
-           }
-       }
-       else if constexpr(NDimSpatial == 2)
+       if constexpr(NDimSpatial == 2)
        {
-           if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 3)
        {
-           if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
...
@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            return false;
        }

+       if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                    is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+       {
+           if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
+           {
+               return false;
+           }
+           if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
+           {
+               return false;
+           }
+           const index_t input_spatial_acum = ck::accumulate_n<index_t>(
+               arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
+           const index_t output_spatial_acum = ck::accumulate_n<index_t>(
+               arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
+           if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
+           {
+               return false;
+           }
+           if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
+           {
+               return false;
+           }
+       }
        return true;
    }
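Note: the new checks above reject cases the transpose kernels cannot vectorize. A self-contained equivalent of the spatial-product test (std::accumulate replaces ck::accumulate_n; the sample lengths in the comment are made up):

    #include <array>
    #include <functional>
    #include <numeric>

    template <std::size_t N>
    bool spatial_extent_vectorizable(const std::array<int, N>& spatial_lengths,
                                     int src_scalar_per_vector)
    {
        // Flattened H*W (or D*H*W) extent must be a multiple of the vector width.
        const int acum = std::accumulate(
            spatial_lengths.begin(), spatial_lengths.end(), 1, std::multiplies<>());
        return acum % src_scalar_per_vector == 0;
    }

    // e.g. spatial_extent_vectorizable(std::array<int, 2>{28, 28}, 4) -> true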
...
@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
-           << NumGroupsToMerge
-           << ">";
+           << NumGroupsToMerge;
+       if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
+                    is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
+       {
+           str << ", TransposeTransferSrcScalarPerVector: "
+               << TransposeTransferSrcScalarPerVector << ", "
+               << "TransposeTransferDstScalarPerVector: "
+               << TransposeTransferDstScalarPerVector;
+       }
+       str << ">";
        // clang-format on
        return str.str();
...

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @ d71189ff
...
@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
            return false;
        }

-       if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                      is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+       if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                      is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
        {
            return false;
        }
...

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @ d71189ff
...
@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
        }

        if constexpr(NDimSpatial == 1)
        {
-           if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+           if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 2)
        {
-           if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
        }
        else if constexpr(NDimSpatial == 3)
        {
-           if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                          is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+           if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                          is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
            {
                return false;
            }
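Note: the renames above (e.g. is_GNWK_GKXC_GNWC to is_GNWC_GKXC_GNWK) swap the K and C tags so the trait names match the bwd-weight operand order: input activations carry C, output gradients carry K. A reduced sketch of such a trait with placeholder tag types (not the CK definitions):

    #include <type_traits>

    struct GNWC {}; struct GKXC {}; struct GNWK {}; // placeholder layout tags

    template <typename InLayout, typename WeiLayout, typename OutLayout>
    constexpr bool is_GNWC_GKXC_GNWK()
    {
        return std::is_same_v<InLayout, GNWC> && std::is_same_v<WeiLayout, GKXC> &&
               std::is_same_v<OutLayout, GNWK>;
    }

    static_assert(is_GNWC_GKXC_GNWK<GNWC, GKXC, GNWK>());
    static_assert(!is_GNWC_GKXC_GNWK<GNWK, GKXC, GNWC>()); // old (reversed) order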
...

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @ d71189ff
...
@@ -102,10 +102,9 @@ __global__ void
    // offset base pointer for each work-group
    const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
    const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

-   const long_index_t e_batch_offset =
+   const long_index_t e_group_offset =
        amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
-   const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+   const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
    const long_index_t e_n_offset =
        amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
...
@@ -118,14 +117,14 @@ __global__ void
            DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
        static_for<0, NumDTensor, 1>{}(
-           [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+           [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

    if constexpr(isMultiA || isMultiB)
    {
        AsPointer p_as_grid_grp;
        BsPointer p_bs_grid_grp;

-       const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+       const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);

        // compute_ptr_offset_of_n_ not need BatchStrideB so
        // in case of MultiA is false but isMultiB is true
...
@@ -136,27 +135,27 @@ __global__ void
            static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
            static_for<0, NumATensor, 1>{}([&](auto i) {
-               p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+               p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
            });
        }
        else
        {
            const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
            static_for<0, 1, 1>{}(
-               [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+               [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
        }

-       const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
+       const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);

        static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
        static_for<0, NumBTensor, 1>{}(
-           [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+           [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });

        GridwiseGemm::template Run<HasMainKBlockLoop>(
            p_as_grid_grp,
            p_bs_grid_grp,
            p_ds_grid_grp,
-           p_e_grid + e_batch_offset + e_n_offset,
+           p_e_grid + e_group_offset + e_n_offset,
            p_shared,
            a_element_op,
            b_element_op,
...
@@ -169,19 +168,19 @@ __global__ void
    }
    else
    {
-       const long_index_t a_batch_offset =
+       const long_index_t a_group_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
-       const long_index_t b_batch_offset =
+       const long_index_t b_group_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
        const long_index_t a_n_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));

        GridwiseGemm::template Run<HasMainKBlockLoop>(
-           p_as_grid + a_batch_offset + a_n_offset,
+           p_as_grid + a_group_offset + a_n_offset,
-           p_bs_grid + b_batch_offset,
+           p_bs_grid + b_group_offset,
            p_ds_grid_grp,
-           p_e_grid + e_batch_offset + e_n_offset,
+           p_e_grid + e_group_offset + e_n_offset,
            p_shared,
            a_element_op,
            b_element_op,
...
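Note: the batch-to-group renames in this kernel are cosmetic: blockIdx.y still selects a convolution group and blockIdx.z an N-split, and each index is converted to an element offset through a stride. A standalone sketch of that arithmetic (names are illustrative, not the CK signatures):

    #include <cstdint>

    using long_index_t = std::int64_t;

    // Offset of this workgroup's slice: group index times the per-group stride
    // plus N-split index times the per-split stride, as GetEPtrOffset/GetAPtrOffset do.
    long_index_t workgroup_offset(long_index_t g_idx, long_index_t group_stride,
                                  long_index_t n_idx, long_index_t n_split_stride)
    {
        return g_idx * group_stride + n_idx * n_split_stride;
    }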
@@ -283,7 +282,8 @@ template <index_t NDimSpatial,
          // in tuple for MultiAB), unpack if tuple was
          // passed
          typename BComputeDataType = AComputeDataType,
-         LoopScheduler LoopSched   = make_default_loop_scheduler()>
+         LoopScheduler LoopSched   = make_default_loop_scheduler(),
+         index_t NumGroupsToMerge  = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                             ALayout,
...
@@ -302,6 +302,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

+   static_assert(NumGroupsToMerge >= 1);
+
    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
...
@@ -318,7 +320,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                          ConvForwardSpecialization,
                                          true /*SplitN*/,
                                          ADataType,
-                                         EDataType>;
+                                         EDataType,
+                                         NumGroupsToMerge>;

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
...
@@ -517,7 +520,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            {
                static_for<0, NumATensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_groups_ for multiple AB
-                   compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+                   compute_ptr_offset_of_groups_.BatchStrideA_(i) =
+                       a_g_n_c_wis_strides[0] * NumGroupsToMerge;

                    // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                    // type is not tuple)
...
@@ -545,7 +549,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                });

                static_for<0, NumBTensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_groups_ for multiple AB
-                   compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+                   compute_ptr_offset_of_groups_.BatchStrideB_(i) =
+                       b_g_k_c_xs_strides[0] * NumGroupsToMerge;

                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                    // It is possible that one of the AB is a pointer and one is a tuple.
...
@@ -565,8 +570,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            }
            else
            {
-               compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-               compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+               compute_ptr_offset_of_groups_.BatchStrideA_ =
+                   a_g_n_c_wis_strides[0] * NumGroupsToMerge;
+               compute_ptr_offset_of_groups_.BatchStrideB_ =
+                   b_g_k_c_xs_strides[0] * NumGroupsToMerge;
                compute_ptr_offset_of_n_.BatchStrideA_ =
                    a_g_n_c_wis_strides[1] * conv_N_per_block_;

                // p_as and p_bs are pointers
...
@@ -583,7 +590,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                    // D batch stride
-                   compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+                   compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
+                       ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
                    compute_ptr_offset_of_n_.BatchStrideDs_(i) =
                        ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
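Note: every group stride in this constructor is multiplied by NumGroupsToMerge for the same reason: when one workgroup covers NumGroupsToMerge consecutive groups, incrementing the launch-grid group index must advance the base pointer past all of them. A toy compile-time check of that arithmetic (the numbers are made up):

    constexpr int num_groups_to_merge = 4;
    constexpr long long stride_per_group = 1024; // e.g. e_g_n_k_wos_strides[0]
    constexpr long long stride_per_launch_index = stride_per_group * num_groups_to_merge;
    // Launch-grid group index 2 addresses original group 2 * 4 = 8:
    static_assert(2 * stride_per_launch_index == 8 * stride_per_group);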
...
@@ -602,7 +610,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                ds_grid_desc_m_n_(i) =
                    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
            });
-           compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+           compute_ptr_offset_of_groups_.BatchStrideE_ =
+               e_g_n_k_wos_strides[0] * NumGroupsToMerge;
            compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

            // populate desc for Ds/E
...
@@ -726,7 +734,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
        const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
-       const index_t gdy = arg.num_group_;
+       const index_t gdy = arg.num_group_ / NumGroupsToMerge;
        const index_t gdz = num_workgroups_per_Conv_N;

        const auto K =
...
@@ -850,6 +858,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    {
        namespace ctc = tensor_layout::convolution;

+       const index_t G = arg.b_g_k_c_xs_lengths_[I0];
+       const index_t K = arg.b_g_k_c_xs_lengths_[I1];
+       const index_t C = arg.b_g_k_c_xs_lengths_[I2];

        // check device
        if(get_device_name() == "gfx908")
        {
...
@@ -898,6 +910,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                }
            }
        }
+       else if constexpr(ConvForwardSpecialization ==
+                         ConvolutionForwardSpecialization::Filter3x3)
+       {
+           if(C != 1)
+           {
+               return false;
+           }
+           for(index_t i = 0; i < NDimSpatial; ++i)
+           {
+               const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
+               if(filter_spatial_dim != I3)
+               {
+                   return false;
+               }
+           }
+           if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+           {
+               return false;
+           }
+       }
+       if constexpr(NumGroupsToMerge > 1)
+       {
+           if(!(C == 1))
+           {
+               return false;
+           }
+           if(G % NumGroupsToMerge != 0)
+           {
+               return false;
+           }
+           if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+           {
+               return false;
+           }
+       }
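Note: condensed, the NumGroupsToMerge admissibility test added above reduces to a predicate like the following sketch (parameter names are illustrative; the real layout constraint is checked at compile time):

    bool groups_mergeable(int G, int C, bool layout_is_NSpatialGC_GKSpatial_NSpatialGK,
                          int num_groups_to_merge)
    {
        // Merging folds several groups into one workgroup's tile, which is only
        // sound when C == 1 (G is then contiguous in memory) and G divides evenly.
        return C == 1 && G % num_groups_to_merge == 0 &&
               layout_is_NSpatialGC_GKSpatial_NSpatialGK;
    }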
        // check vector access of A
        // FIXME: layout
...
@@ -907,13 +955,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
                     is_same_v<ALayout, ctc::NDHWGC>)
        {
-           const index_t C = arg.a_g_n_c_wis_lengths_[2];
+           // Check access per C
            if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
            {
+               // If not possible, check access per G
+               if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
+                    is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
+                    G % ABlockTransferSrcScalarPerVector == 0))
+               {
                    return false;
+               }
            }
        }
        else
        {
            return false;
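Note: the reworked A-tensor check above tries two vectorization strategies in order; as a standalone predicate it is roughly the following sketch (equivalent logic under illustrative names):

    bool a_vector_access_ok(int G, int C, int vector_dim, int scalar_per_vector,
                            bool layout_is_NSpatialGC_GKSpatial_NSpatialGK)
    {
        // Preferred: vectorize along C (vector_dim == 2).
        if(vector_dim == 2 && C % scalar_per_vector == 0)
            return true;
        // Fallback for merged groups: C == 1 makes G contiguous, so vectorize along G.
        return vector_dim == 1 && C == 1 && layout_is_NSpatialGC_GKSpatial_NSpatialGK &&
               G % scalar_per_vector == 0;
    }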
...
@@ -928,8 +981,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<BLayout, ctc::KZYXGC>)
        {
-           const index_t C = arg.b_g_k_c_xs_lengths_[2];
-
            if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
...
@@ -953,8 +1004,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                         is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
                         is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
            {
-               const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
-
                if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
                {
                    valid = false;
...
@@ -999,8 +1048,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
                     is_same_v<ELayout, ctc::NDHWGK>)
        {
-           const index_t K = arg.e_g_n_k_wos_lengths_[2];
-
            if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
            {
                return false;
...
@@ -1298,7 +1345,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            << BBlockTransferSrcScalarPerVector << ", "
            << CDEBlockTransferScalarPerVector_NPerBlock << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
-           << CShuffleNXdlPerWavePerShuffle << ">";
+           << CShuffleNXdlPerWavePerShuffle << ", " << NumGroupsToMerge << ">";
        // clang-format on
...