"vscode:/vscode.git/clone" did not exist on "cd0c1f5709df97364bba9de3d3885d25b2ec4fff"
Unverified commit 171ed358, authored by Illia Silin, committed by GitHub

Merge pull request #148 from ROCm/merge_from_public

Merge from public
parents 829e0eb3 e536d321
@@ -26,7 +26,19 @@ set(version 1.1.0)
 project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
 include(CTest)
-find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+if(NOT CK_USE_ALTERNATIVE_PYTHON)
+    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+else()
+    message("Using alternative python version")
+    set(EXTRA_PYTHON_PATH)
+    string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}")
+    message("alternative python path is: ${EXTRA_PYTHON_PATH}")
+    find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
+    add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}")
+    set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
+    set(PYTHON_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}")
+    set(ENV{LD_LIBRARY_PATH} "${EXTRA_PYTHON_PATH}/lib:$ENV{LD_LIBRARY_PATH}")
+endif()
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
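A note on the alternative-Python block above: the string(REPLACE ...) only strips a literal "/bin/python3.8" suffix, so the derived prefix is only correct when the alternative interpreter really is named python3.8. A minimal Python sketch of the same computation (the sample path is made up):

    # Mirror of the CMake string(REPLACE ...) logic above; the path is hypothetical.
    alt_python = "/opt/conda/bin/python3.8"
    extra_python_path = alt_python.replace("/bin/python3.8", "")  # -> "/opt/conda"
    ld_library_entry = extra_python_path + "/lib"                 # prepended to LD_LIBRARY_PATH
    print(extra_python_path, ld_library_entry)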
...
@@ -262,10 +262,19 @@ def cmake_build(Map conf=[:]){
     // reduce parallelism when compiling, clang uses too much memory
     def nt = nthreads()
     def cmd
+    def setup_cmd
+    def build_cmd
     def execute_cmd = conf.get("execute_cmd", "")
     if(!setup_args.contains("NO_CK_BUILD")){
-        def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
-        def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+            echo "running ninja build trace"
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake -G Ninja ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}")
+        }
+        else{
+            setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
+            build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
+        }
         cmd = conf.get("cmd", """
             ${setup_cmd}
             ${build_cmd}
@@ -281,7 +290,19 @@ def cmake_build(Map conf=[:]){
     echo cmd
     dir("build"){
+        //build CK
         sh cmd
+        //run tests
+        if(!setup_args.contains("NO_CK_BUILD")){
+            if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){
+                sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json"
+                archiveArtifacts "ck_build_trace.json"
+                sh "ninja test"
+            }
+            else{
+                sh "make check"
+            }
+        }
     }
     // Only archive from master or develop
@@ -543,7 +564,7 @@ def Build_CK(Map conf=[:]){
     cmake_build(conf)
     dir("build"){
         //run tests and examples
-        sh 'make -j check'
+        //sh 'make -j check'
         if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){
             //we only need the ckProfiler to run the performance tests, so we pack and stash it
             //do not stash profiler on nodes where we don't need to run performance tests
@@ -755,7 +776,10 @@ pipeline {
             name: "BUILD_GFX12",
             defaultValue: false,
             description: "Build CK and run tests on gfx12 (default: OFF)")
+        booleanParam(
+            name: "NINJA_BUILD_TRACE",
+            defaultValue: false,
+            description: "Generate a ninja build trace (default: OFF)")
     }
     environment{
         dbuser = "${dbuser}"
@@ -789,6 +813,7 @@ pipeline {
        }
        agent{ label rocmnode("nogpu") }
        environment{
+            setup_args = "NO_CK_BUILD"
            execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
                -o -not -path \'*.git*\' -iname \'*.hpp\' \
                -o -not -path \'*.git*\' -iname \'*.cpp\' \
@@ -805,7 +830,7 @@ pipeline {
                --file-filter=*.cpp --force --enable=all --output-file=ck_cppcheck.log"
        }
        steps{
-            buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+            buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
            archiveArtifacts "build/ck_cppcheck.log"
            cleanWs()
        }
@@ -817,6 +842,7 @@ pipeline {
        }
        agent{ label rocmnode("nogpu") }
        environment{
+            setup_args = "NO_CK_BUILD"
            execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \
                -o -not -path \'*.git*\' -iname \'*.hpp\' \
                -o -not -path \'*.git*\' -iname \'*.cpp\' \
@@ -828,7 +854,7 @@ pipeline {
                | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
        }
        steps{
-            buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
+            buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
            cleanWs()
        }
    }
@@ -957,10 +983,10 @@ pipeline {
        }
        agent{ label rocmnode("gfx90a") }
        environment{
-            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
+            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
            execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-                -DGPU_TARGETS="gfx1100;gfx90a" \
+                -DGPU_TARGETS="gfx90a" \
                -DCMAKE_CXX_COMPILER="${build_compiler()}" \
                -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
        }
@@ -1064,7 +1090,7 @@ pipeline {
        options { retry(1) }
        agent{ label rocmnode("gfx90a")}
        environment{
-            setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """
+            setup_args = "NO_CK_BUILD"
        }
        steps{
            runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
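For readers following the Jenkinsfile changes, a hedged Python paraphrase of the new command selection in cmake_build() (this is not the pipeline code itself; the names mirror the Groovy variables):

    # gfx90a jobs with NINJA_BUILD_TRACE build under Ninja so that
    # /ninjatracing/ninjatracing can turn .ninja_log into a Chrome trace.
    def pick_commands(setup_args: str, ninja_build_trace: bool, nthreads: int = 16):
        if "gfx90a" in setup_args and ninja_build_trace:
            setup_cmd = f"cmake -G Ninja {setup_args} .."
            build_cmd = f"ninja -j{nthreads}"
        else:
            setup_cmd = f"cmake {setup_args} .."
            build_cmd = f"dumb-init make -j{nthreads}"
        return setup_cmd, build_cmd

    print(pick_commands('-DGPU_TARGETS="gfx90a"', True))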
...
-# generate a list of kernels, but not actually emit files at config stage
+# validate user-specified fmha_fwd API list
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
+set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
+  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
+endif()
+
+foreach(api ${FMHA_FWD_ENABLE_APIS})
+  if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
+    message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
+  endif()
+endforeach()
+
+# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
+if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
+endif()
+
+string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
+# generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 execute_process(
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 add_custom_command(
@@ -60,6 +80,20 @@ else()
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
+# conditionally enable call to the fwd_splitkv API in fmha_fwd example
+if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
+endif()
+
+# conditionally enable call to the fwd_appendkv API in fmha_fwd example
+if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
+endif()
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
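To see what --api string generate.py ends up receiving, here is a hedged Python rendering of the FMHA_FWD_ENABLE_APIS handling above (not project code):

    KNOWN_APIS = ["fwd", "fwd_splitkv", "fwd_appendkv"]

    def normalize_apis(enable_apis: str) -> str:
        apis = KNOWN_APIS[:] if enable_apis == "all" else enable_apis.split(";")
        for api in apis:
            if api not in KNOWN_APIS:
                raise ValueError(f"{api} isn't a known api: {KNOWN_APIS}")
        if "fwd" not in apis:      # "fwd" is a must-have for the fmha_fwd example
            apis.append("fwd")
        return ",".join(apis)      # the CMake above joins the list with ","

    assert normalize_apis("all") == "fwd,fwd_splitkv,fwd_appendkv"
    assert normalize_apis("fwd_appendkv") == "fwd_appendkv,fwd"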
...
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
     "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
 }
+ROPE_MAP = {
+    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
+    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
+    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
+}
+
+ROPE_CHECK_MAP = {
+    "no"    : "rope_enum::none",
+    "inter" : "rope_enum::interleaved",
+    "half"  : "rope_enum::half_rotated"
+}
+
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
 }
...
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
DTYPE_BITS,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_occupancy}>;
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
{F_bs},
{F_bsk},
{F_bd},
{F_bdv},
{F_vlayout},
{F_rope},
{F_pagedkv},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
fmha_pipeline_problem_{F_idx}>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
fmha_pipeline_{F_idx}>;
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
bdv : int # tile size along kv gemm unroll
vlayout : str
spad : str
skpad : str
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
@property
def scheck(self) -> str:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
@property
def skcheck(self) -> str:
# we do not check all the values in a.seqlen_k_ptr
return 'true'
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bd} == 0'
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bdv} == 0'
@dataclass
class FmhaFwdAppendKVPipeline:
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += '_pagedkv'
return n
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdAppendKVTileSize:
F_bs : int # tile size along q seqlen
F_bsk : int # tile size along k seqlen
F_bd : int # tile size along qk gemm unroll
F_bdv : int # tile size along kv gemm unroll
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
F_bdv = self.F_tile.F_bdv,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
}
else:
return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
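A usage sketch for the generator classes in the new append-KV module above (the values are illustrative, and the mask_impl string is a placeholder):

    tile = FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1)
    pipe = FmhaFwdAppendKVPipeline('row', 'f', 't', 'f', 'f', 'inter', 't')
    kern = FmhaFwdAppendKVKernel(F_idx=0, F_hdim=32, F_dtype='fp16',
                                 F_tile=tile, F_pipeline=pipe, mask_impl='simplified')
    print(kern.name)  # fmha_fwd_appendkv_d32_fp16_b64x64x32x32_vr_psk_inter_pagedkv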
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
 )
 
+DTYPE_BITS = {
+    "fp32": 32,
+    "fp16": 16,
+    "bf16": 16,
+    "fp8" : 8,
+    "bf8" : 8
+}
+
 FMHA_FWD_SPLITKV_PIPELINE_MAP = {
     "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
     "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
     {F_bias},
     false,
     {F_lse},
-    {F_dropout},
     {F_squant},
+    {F_pagedkv},
     kHasUnevenSplits,
     {F_occupancy}>;
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
-    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
@@ -86,7 +93,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;
 
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
     }};
 }}
 
-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
+    {F_dvpad}>;
 
 #include <iostream>
 
 template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if constexpr({F_mode} == false) {{ // batch mode
-        if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
+        // we don't check every seqlen_k values for kvcache
+        if (a.seqlen_k_ptr != nullptr) {{
+            kernel_runner<true>::run(s, a);
+        // make sure F_bn0 is divisible by F_bk1
+        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
             kernel_runner<false>::run(s, a);
         }} else {{
             kernel_runner<true>::run(s, a);
@@ -160,7 +172,7 @@ using fmha_kernel =
     fmha_pipeline,
     fmha_epilogue>;
 
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
 #include <iostream>
 
 template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
 #include <iostream>
 
 template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
-float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if(s.log_level_ > 0)
         std::cout
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
         );
 }}
 
-float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
     float r = -1;
     {F_dispatch}
     return r;
 }}
 """
 
-FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
-        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}
 """
 
+@dataclass
+class FmhaFwdSplitKVApiTrait:
+    pipeline_tag : str
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim    : str
+    dtype   : str # data type
+    mode    : str # value from MODE_MAP
+    bm0     : int # tile size along q seqlen (block size)
+    bn0     : int # tile size along qk seqlen
+    bk0     : int # tile size along qk gemm unroll
+    bn1     : int # tile size along v head_dim
+    bk1     : int # tile size along kv gemm unroll
+    bk0blen : int
+    vlayout : str
+    mask    : str
+    bias    : str #
+    lse     : str #
+    squant  : str #
+    spad    : str
+    skpad   : str
+    dpad    : str
+    dvpad   : str
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
+            f'{self.dvpad}-{self.pagedkv}'
+
+    @property
+    def scheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.spad == 't' : return 'true' # always support
+            else :                return 'true'
+        elif self.pipeline_tag in ['qr']:
+            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.seqlen_q % {self.bm0} == 0'
+        else: assert False
+
+    @property
+    def skcheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
+            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
+        elif self.pipeline_tag in ['qr', 'qr_fp8']:
+            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                 return f'a.seqlen_k % {self.bn0} == 0'
+        else: assert False
+
+    @property
+    def dcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
+            else :               assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :               return f'a.hdim_q % {self.bk0blen} == 0'
+        else: assert False
+
+    @property
+    def dvcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
+            else :                assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else :                return f'a.hdim_v % {self.bk0blen} == 0'
+        else: assert False
+
 @dataclass
 class FmhaFwdSplitKVPipeline:
     tag : str
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
     F_dvpad : str #
     F_bias : str # true/false
     F_lse : str #
-    F_dropout : str #
     F_squant : str #
+    F_pagedkv : str # t/f
     F_mask : str # value from MASK_MAP
 
     @property
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
         else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
            if self.F_lse == 't' : n += '_lse'
-           if self.F_dropout == 't' : n += '_dropout'
            if self.F_squant == 't' : n += '_squant'
+           if self.F_pagedkv == 't' : n += '_pagedkv'
        return n
 
 @dataclass
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
         self.pool = dict()
         self.mask_impl = mask_impl
 
-    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
+    def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
         # TODO: do we need to check duplication?
         if trait.dtype not in self.pool.keys():
             self.pool[trait.dtype] = dict()
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
                 inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                     F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                     F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                    F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
-                    F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
+                    F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
+                    F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                     F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                     F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                     F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
             F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
             F_bias = BIAS_MAP[self.F_pipeline.F_bias],
             F_lse = BOOL_MAP[self.F_pipeline.F_lse],
-            F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
             F_squant = BOOL_MAP[self.F_pipeline.F_squant],
+            F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
             F_occupancy = self.F_tile.F_occupancy,
             F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
             F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
     def filename(self) -> str:
         return self.name + ".cpp"
 
-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
+    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
+        return FmhaFwdSplitKVApiTrait(
             pipeline_tag=self.F_pipeline.tag,
             hdim=str(self.F_hdim),
             dtype=self.F_dtype,
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
             mask=self.F_pipeline.F_mask,
             bias=self.F_pipeline.F_bias,
             lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            pagedkv=self.F_pipeline.F_pagedkv,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
     def filename(self) -> str:
         return self.name + ".cpp"
 
-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
-            pipeline_tag=self.F_pipeline.tag,
-            hdim=str(self.F_hdim),
-            dtype=self.F_dtype,
-            mode=self.F_mode,
-            bm0=self.F_tile.F_bm0,
-            bn0=self.F_tile.F_bn0,
-            bk0=self.F_tile.F_bk0,
-            bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0blen=self.F_tile.F_bk0blen,
-            vlayout=self.F_pipeline.F_vlayout,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
-            squant=self.F_pipeline.F_squant,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad)
-
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            # splitkv kernel donot support dropout
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
-                if hdim == 256:
+            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+                # TODO: use async pipeline when compiler is more stable
+                if hdim == 256 or hdim in [32, 64, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/dropout kernels
+            # no need lse/paged-kv kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
         else:
            assert False
        return pipelines
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
                 if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                     # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                     continue
+                if pipeline.F_pagedkv == 't':
+                    # we only use batch mode kernels to handle (paged-) kvcache problems
+                    continue
                 k = Kernel(F_idx=0,
                     F_hdim=hdim,
                     F_dtype=dtype,
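Both API pools emit their runtime dispatch with the same chaining idiom: the first registered trait gets "if", every later one "else if". A tiny self-contained illustration (toy template and traits, not the real ones):

    TEMPLATE = "{F_if}(t.rope_type == {F_rope_check}) {{ /* call kernel */ }}\n"
    traits = ["rope_enum::none", "rope_enum::interleaved", "rope_enum::half_rotated"]
    dispatch = ""
    for k, rope_check in enumerate(traits):
        if_k = "if" if k == 0 else "else if"
        dispatch += TEMPLATE.format(F_if=if_k, F_rope_check=rope_check)
    print(dispatch)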
...
@@ -4,6 +4,7 @@
 #include "fmha_fwd.hpp"
 #include "ck_tile/host.hpp"
 #include "mask.hpp"
+#include "rotary.hpp"
 #include "utils.hpp"
 
 #include <array>
@@ -16,6 +17,10 @@
 #include <utility>
 #include <vector>
 
+#if CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+#error "we should enable fmha_fwd_splitkv() api in order to cooperate with fmha_fwd_appendkv()"
+#endif
+
 template <typename T>
 std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
 {
@@ -50,7 +55,11 @@ auto create_args(int argc, char* argv[])
             "seqlen_q. if group-mode, means the average value of seqlen_q\n"
             "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
             "also with \"-s=s0,s1,s2...\" comma seperated int to set per batch seqlen(group-mode)")
-        .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
+        .insert("s_k", "-1", "seqlen_k (including new key/value), -1 means equal to s")
+        .insert("s_knew",
+                "0",
+                "seqlen_k for new key/value, 0 means not to use this at all; "
+                "-1 to choose s_knew in [1, s] randomly.")
         .insert("s_kpad",
                 "-1",
                 "seqlen_k stride between 2 tokens, currently used in group-mode only\n"
@@ -114,9 +123,14 @@ auto create_args(int argc, char* argv[])
         .insert("drop_seed", "1", "seed for random number generator")
         .insert("drop_offset", "0", "offset for random number generator")
         .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
+        .insert(
+            "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
+        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
         .insert("num_splits",
                 "1",
                 "# of splits for key/value. 0 to determine actual number by heuristic")
+        .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcahe")
+        .insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel");
@@ -244,20 +258,6 @@ int override_num_splits_if_necessary(
     return num_splits;
 }
 
-float fmha_fwd_dispatch(fmha_fwd_traits traits,
-                        fmha_fwd_args args,
-                        const ck_tile::stream_config& config)
-{
-    if(1 < args.num_splits)
-    {
-        return fmha_fwd_splitkv(traits, args, config);
-    }
-    else
-    {
-        return fmha_fwd(traits, args, config);
-    }
-}
-
 template <typename DataType>
 bool run(const ck_tile::ArgParser& arg_parser)
 {
@@ -276,11 +276,114 @@ bool run(const ck_tile::ArgParser& arg_parser)
         return false;
     }
 
+    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
+    if(*seed == 0)
+    {
+        seed.reset();
+    }
+
+    ck_tile::index_t hdim_q = arg_parser.get_int("d");
+    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
+    if(hdim_v < 0)
+        hdim_v = hdim_q;
+
+    ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    if(seqlen_knew != 0)
+    {
+        std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
+        seqlen_knew = 0;
+    }
+#endif
+    if(seqlen_knew < 0)
+    {
+        seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
+    }
+
+    ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
+    if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
+                   std::is_same_v<DataType, ck_tile::bf16_t>))
+    {
+        if(0 < rotary_dim)
+        {
+            std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
+            return false;
+        }
+    }
+#if !CK_TILE_FMHA_FWD_APPENDKV_API
+    else if(0 < rotary_dim)
+    {
+        std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
+                  << std::endl;
+        rotary_dim = 0;
+    }
+#endif
+    if(!(rotary_dim <= hdim_q))
+    {
+        std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
+        return false;
+    }
+    else if(!(rotary_dim % 16 == 0))
+    {
+        std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
+        return false;
+    }
+
+    ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < page_block_size)
+    {
+        std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
+                  << std::endl;
+        page_block_size = 0;
+    }
+#endif
+    if(!(page_block_size % 128 == 0))
+    {
+        std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
+                  << std::endl;
+        return false;
+    }
+
+    bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
+#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(use_cache_batch_idx)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+#endif
+    if(0 < page_block_size && use_cache_batch_idx)
+    {
+        std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
+                     "'cache_batch_idx' option"
+                  << std::endl;
+        use_cache_batch_idx = false;
+    }
+
+    // the input tensor layout for kvcache is same as batch mode
+    const bool need_append_kvcache = (0 < seqlen_knew || 0 < rotary_dim);
+    const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size);
+    if(use_kvcache && mode != mode_enum::batch)
+    {
+        std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
+        mode = mode_enum::batch;
+    }
+
-    auto [seqlen_qs, seqlen_ks, seqlen_kpads] = decode_seqlen(mode,
+    auto [seqlen_qs, seqlen_ks, seqlen_kpads] =
+        decode_seqlen(mode,
                       batch,
                       arg_parser.get_str("s"),
                       arg_parser.get_str("s_k"),
-                      arg_parser.get_str("s_kpad"));
+                      arg_parser.get_str("s_kpad"),
+                      /*seqlen_k_min=*/0 < seqlen_knew ? seqlen_knew : 0,
+                      use_kvcache);
+
+    // compute kvcache seqlen_k (before appending knew/vnew)
+    auto cache_seqlen_ks = seqlen_ks;
+    std::transform(cache_seqlen_ks.begin(),
+                   cache_seqlen_ks.end(),
+                   cache_seqlen_ks.begin(),
+                   [&](auto seqlen_k) { return seqlen_k - seqlen_knew; });
+
 #if 0
     // clang-format off
@@ -290,11 +393,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // clang-format on
 #endif
 
-    ck_tile::index_t hdim_q = arg_parser.get_int("d");
-    ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
-    if(hdim_v < 0)
-        hdim_v = hdim_q;
-
     bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
     bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -357,13 +455,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
     }
 
     std::string init_method = arg_parser.get_str("init");
-    std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
-    if(*seed == 0)
-    {
-        seed.reset();
-    }
-    int num_splits = arg_parser.get_int("num_splits");
+
+    const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
+
+    ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
+#if !CK_TILE_FMHA_FWD_SPLITKV_API
+    if(num_splits != 1)
+    {
+        std::cerr << "split-kv is not supported. ignoring the 'num_splits' option" << std::endl;
+        num_splits = 1;
+    }
+#endif
 
     int stream_warmup = arg_parser.get_int("warmup");
     int stream_repeat = arg_parser.get_int("repeat");
@@ -425,6 +527,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
     }
 
+    const ck_tile::index_t max_num_page_blocks =
+        (0 < page_block_size
+             ? batch * std::max(1, ck_tile::integer_divide_ceil(max_seqlen_k, page_block_size))
+             : 0);
+
     // legalize num_splits according to other options
     if(num_splits < 1)
     {
@@ -436,6 +543,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
         std::cerr << "num_splits greater than 128 is not supported" << std::endl;
         return false;
     }
+#if CK_TILE_FMHA_FWD_SPLITKV_API
+    if(0 < p_drop && (1 < num_splits || use_kvcache))
+    {
+        std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
+                  << std::endl;
+        p_drop = 0.0f;
+    }
+#endif
 
     auto get_lengths = [&](bool permute,
                            ck_tile::index_t b /*batch*/,
@@ -462,11 +577,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<QDataType> q_host(
         get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
     ck_tile::HostTensor<KDataType> k_host(
-        get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+        0 < page_block_size
+            ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
+            : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
+    /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
+    ck_tile::HostTensor<KDataType> knew_host(
+        0 < seqlen_knew
+            ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_q)
+            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor<VDataType> v_host(
-        is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
-                      : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
+        0 < page_block_size
+            ? (is_v_rowmajor
+                   ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_v)
+                   : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
+            : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
+                             : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
+    ck_tile::HostTensor<VDataType> vnew_host(
+        0 < seqlen_knew
+            ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
+                             : get_lengths(i_perm, batch, nhead_k, hdim_v, seqlen_knew))
+            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
     ck_tile::HostTensor<BiasDataType> bias_host(
         bias.type == bias_enum::elementwise_bias
             ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
@@ -478,12 +608,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
             : std::array<ck_tile::index_t, 2>{batch, nhead})
         : std::array<ck_tile::index_t, 2>{1, 1});
 
+    auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
+        std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
+
     ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits
+        1 < num_splits || use_kvcache
            ? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
-        1 < num_splits
+        1 < num_splits || use_kvcache
            ? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
            : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
@@ -500,39 +633,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
        p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
                   : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
 
+    ck_tile::HostTensor<int32_t> block_table_host(
+        0 < page_block_size ? std::array<ck_tile::index_t, 2>{batch, max_num_page_blocks / batch}
+                            : std::array<ck_tile::index_t, 2>{1, 1});
+    ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
+                                                          ? std::array<ck_tile::index_t, 1>{batch}
+                                                          : std::array<ck_tile::index_t, 1>{1});
+
     if(init_method == "ui" || init_method == "0")
     {
        ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillUniformDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillUniformDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
        ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
    }
    else if(init_method == "ni")
    {
        ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, seed}(q_host);
        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(k_host);
+        ck_tile::FillNormalDistributionIntegerValue<KDataType>{-3.f, 3.f, seed}(knew_host);
        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(v_host);
+        ck_tile::FillNormalDistributionIntegerValue<VDataType>{-3.f, 3.f, seed}(vnew_host);
        ck_tile::FillNormalDistributionIntegerValue<BiasDataType>{-3.f, 3.f, seed}(bias_host);
    }
    else if(init_method == "uf" || init_method == "1")
{ {
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
} }
else if(init_method == "nf") else if(init_method == "nf")
{ {
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host); ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host); ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(k_host);
ck_tile::FillNormalDistribution<KDataType>{0.f, 3.f, seed}(knew_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host); ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(v_host);
ck_tile::FillNormalDistribution<VDataType>{0.f, 3.f, seed}(vnew_host);
ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host); ck_tile::FillNormalDistribution<BiasDataType>{0.f, 3.f, seed}(bias_host);
} }
else if(init_method == "tf" || init_method == "2") else if(init_method == "tf" || init_method == "2")
{ {
ck_tile::FillTrigValue<QDataType>{}(q_host); ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host); ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<KDataType>{}(knew_host);
ck_tile::FillTrigValue<VDataType>{}(v_host); ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host); ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
} }
else if(init_method == "ufq" || init_method == "uf:q" || else if(init_method == "ufq" || init_method == "uf:q" ||
...@@ -540,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -540,7 +691,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32 // bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
...@@ -550,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -550,7 +703,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::alibi) if(bias.type == bias_enum::alibi)
{ {
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead); auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
assert(slopes.size() == nhead); assert(slopes.size() == static_cast<std::size_t>(nhead));
if(bias.rank_info == 0) if(bias.rank_info == 0)
{ {
// alibi in 1*h // alibi in 1*h
...@@ -565,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -565,10 +718,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
} }
} }
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
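iota_shuffle is a helper defined elsewhere in this example; from its use here it appears to fill the range with consecutive indices and then permute them, so every physical page block and cache batch slot is referenced exactly once, in random order. A minimal sketch under that assumption:

#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Assumed behavior of iota_shuffle: write start, start+1, ... into [first, last),
// then permute the range in place.
template <typename Iter, typename T>
void iota_shuffle(Iter first, Iter last, T start)
{
    std::iota(first, last, start);
    std::mt19937 engine(std::random_device{}());
    std::shuffle(first, last, engine);
}

Usage: std::vector<int32_t> table(8); iota_shuffle(table.begin(), table.end(), 0); — afterwards the vector holds a random permutation of 0..7, so no physical block is mapped twice.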
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
...@@ -576,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -576,27 +733,41 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf(seqlen_kpads[0] < 0 ? 0 : seqlen_ks.size() * sizeof(int32_t)); ck_tile::DeviceMem seqlen_k_buf(
use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data()); q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data()); k_buf.ToDevice(k_host.data());
knew_buf.ToDevice(knew_host.data());
v_buf.ToDevice(v_host.data()); v_buf.ToDevice(v_host.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data()); bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data()); : seqstart_k_with_padding_host.data());
seqlen_k_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr : seqlen_ks.data()); seqlen_k_buf.ToDevice(use_kvcache || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr);
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data()); alibi_slope_buf.ToDevice(alibi_slope_host.data());
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
// clang-format off // clang-format off
auto layout_str = [&](bool permute){ auto layout_str = [&](bool permute){
if (permute) return std::string("bhsd"); if(permute) return std::string("bhsd");
else return std::string("bshd"); else return std::string("bshd");
}; };
auto io_layout = [&](bool iperm_, bool operm_) { auto io_layout = [&](bool iperm_, bool operm_) {
if (iperm_ == operm_) return layout_str(iperm_); if(iperm_ == operm_) return layout_str(iperm_);
else return layout_str(iperm_) + std::string("-") + layout_str(operm_); else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
}; };
// clang-format on // clang-format on
...@@ -609,39 +780,57 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -609,39 +780,57 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
<< ", mask:" << mask << ", v:" << vlayout; << ", mask:" << mask << ", v:" << vlayout;
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < rotary_dim)
{
std::cout << ", rotary_dim:" << rotary_dim << "("
<< (is_rotary_interleaved ? "inter" : "half") << ")";
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits) if(1 < num_splits)
{ {
std::cout << ", num_splits:" << num_splits; std::cout << ", num_splits:" << num_splits;
} }
if(0 < page_block_size)
{
std::cout << ", page_block_size:" << page_block_size;
}
if(use_cache_batch_idx)
{
std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
}
#endif
std::cout << std::flush; std::cout << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q, const auto init_traits = [&](auto& traits) {
hdim_v, traits.hdim_q = hdim_q;
data_type, traits.hdim_v = hdim_v;
mode == mode_enum::group, traits.data_type = data_type;
is_v_rowmajor, traits.is_v_rowmajor = is_v_rowmajor;
mask.type,
bias.type,
lse,
p_drop > 0.0f,
squant};
auto p_compute_element_func = [&]() { if constexpr(std::is_same_v<fmha_fwd_appendkv_traits, std::decay_t<decltype(traits)>>)
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) {
return ck_tile::scales{scale_p}; traits.rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
else : rope_enum::half_rotated)
return ck_tile::identity{}; : rope_enum::none);
}(); }
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
traits.do_fp8_static_quant = squant;
auto oacc_element_func = [&]() { if constexpr(std::is_same_v<fmha_fwd_traits, std::decay_t<decltype(traits)>>)
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) {
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, traits.has_dropout = (p_drop > 0.0f);
ck_tile::scales{scale_o}); }
else }
return ck_tile::identity{}; };
}();
auto fmha_args = [&, k_paddings_ = seqlen_kpads]() { const auto init_args = [&, k_paddings_ = seqlen_kpads](auto& args) {
assert(nhead % nhead_k == 0); assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & /// seqlen_k] in this example, hence both the 'batch_stride_bias' &
...@@ -649,11 +838,19 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -649,11 +838,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
// setup stride_* arguments // setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_knew = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() { const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v; return i_perm ? hdim_v : nhead_k * hdim_v;
else else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; return 0 < page_block_size ? (i_perm ? page_block_size : nhead_k * page_block_size)
: (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_knew : nhead_k * seqlen_knew;
}(); }();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k);
...@@ -661,12 +858,23 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -661,12 +858,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments // setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); const ck_tile::index_t nhead_stride_k =
(0 < page_block_size ? (i_perm ? page_block_size * hdim_q : hdim_q)
: (i_perm ? shape_seqlen_k * hdim_q : hdim_q));
const ck_tile::index_t nhead_stride_knew = (i_perm ? seqlen_knew * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() { const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor) if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v; return 0 < page_block_size ? (i_perm ? page_block_size * hdim_v : hdim_v)
: (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
else else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; return 0 < page_block_size ? (i_perm ? hdim_v * page_block_size : page_block_size)
: (i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k);
}();
const ck_tile::index_t nhead_stride_vnew = [&]() {
if(is_v_rowmajor)
return i_perm ? seqlen_knew * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_knew : seqlen_knew;
}(); }();
const ck_tile::index_t nhead_stride_bias = const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
...@@ -677,87 +885,193 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -677,87 +885,193 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments // setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); const ck_tile::index_t batch_stride_k =
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); (0 < page_block_size ? (nhead_k * page_block_size * hdim_q)
: (nhead_k * shape_seqlen_k * hdim_q));
const ck_tile::index_t batch_stride_knew = (nhead_k * seqlen_knew * hdim_q);
const ck_tile::index_t batch_stride_v =
(0 < page_block_size ? (nhead_k * hdim_v * page_block_size)
: (nhead_k * hdim_v * shape_seqlen_k));
const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel) // setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q); const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v); const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(), args.q_ptr = q_buf.GetDeviceBuffer();
k_buf.GetDeviceBuffer(), args.k_ptr = k_buf.GetDeviceBuffer();
v_buf.GetDeviceBuffer(), args.v_ptr = v_buf.GetDeviceBuffer();
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(), args.batch = batch;
randval_buf.GetDeviceBuffer(), args.seqlen_q = shape_seqlen_q; // unused in group mode
lse_acc_buf.GetDeviceBuffer(), args.hdim_q = hdim_q;
o_acc_buf.GetDeviceBuffer(), args.hdim_v = hdim_v;
lse_buf.GetDeviceBuffer(), args.nhead_q = nhead;
o_buf.GetDeviceBuffer(), args.nhead_k = nhead_k;
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(), args.stride_q = stride_q;
k_paddings_[0] < 0 ? nullptr : seqlen_k_buf.GetDeviceBuffer(), args.stride_k = stride_k;
shape_seqlen_q, args.stride_v = stride_v;
shape_seqlen_k, args.nhead_stride_q = nhead_stride_q;
batch, args.nhead_stride_k = nhead_stride_k;
max_seqlen_q, args.nhead_stride_v = nhead_stride_v;
hdim_q, args.batch_stride_q = batch_stride_q;
hdim_v, args.batch_stride_k = batch_stride_k;
nhead, args.batch_stride_v = batch_stride_v;
nhead_k,
num_splits, if constexpr(std::is_same_v<fmha_fwd_appendkv_args, std::decay_t<decltype(args)>>)
scale_s, {
scale_p, args.knew_ptr = knew_buf.GetDeviceBuffer();
scale_o, args.vnew_ptr = vnew_buf.GetDeviceBuffer();
stride_q, args.seqlen_knew = seqlen_knew;
stride_k,
stride_v, args.seqlen_k_ptr = cache_seqlen_k_buf.GetDeviceBuffer();
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias, args.rotary_cos_ptr = (0 < rotary_dim ? rotary_cos_buf.GetDeviceBuffer() : nullptr);
stride_randval, args.rotary_sin_ptr = (0 < rotary_dim ? rotary_sin_buf.GetDeviceBuffer() : nullptr);
stride_o_acc, args.rotary_dim = rotary_dim;
stride_o, args.has_mask = (mask.type != mask_enum::no_mask);
nhead_stride_q,
nhead_stride_k, args.block_table_ptr =
nhead_stride_v, (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
nhead_stride_bias, args.batch_stride_block_table = batch_stride_block_table;
nhead_stride_randval, args.page_block_size = page_block_size;
nhead_stride_lse,
nhead_stride_lse_acc, args.cache_batch_idx =
nhead_stride_o_acc, (use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
nhead_stride_o,
batch_stride_q, args.stride_knew = stride_knew;
batch_stride_k, args.stride_vnew = stride_vnew;
batch_stride_v, args.nhead_stride_knew = nhead_stride_knew;
batch_stride_bias, args.nhead_stride_vnew = nhead_stride_vnew;
batch_stride_randval, args.batch_stride_knew = batch_stride_knew;
batch_stride_lse, args.batch_stride_vnew = batch_stride_vnew;
batch_stride_lse_acc, }
batch_stride_o_acc, else // fmha_fwd_args or fmha_fwd_splitkv_args
batch_stride_o, {
split_stride_lse_acc, args.bias_ptr = bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
split_stride_o_acc, : bias_buf.GetDeviceBuffer();
mask.left, args.lse_ptr = lse_buf.GetDeviceBuffer();
mask.right, args.o_ptr = o_buf.GetDeviceBuffer();
static_cast<ck_tile::index_t>(mask.type),
p_drop, args.seqstart_q_ptr =
s_randval, (mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
{drop_seed, drop_offset}}; args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr =
(use_kvcache || 0 <= k_paddings_[0] ? seqlen_k_buf.GetDeviceBuffer() : nullptr);
args.seqlen_k = shape_seqlen_k; // unused in group mode (or when kvcache is enabled)
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.scale_p = scale_p;
args.scale_o = scale_o;
args.stride_bias =
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
args.stride_o = stride_o;
args.nhead_stride_bias = nhead_stride_bias;
args.nhead_stride_lse = nhead_stride_lse;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_bias = batch_stride_bias;
args.batch_stride_lse = batch_stride_lse;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
if constexpr(std::is_same_v<fmha_fwd_args, std::decay_t<decltype(args)>>)
{
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
args.stride_randval = stride_randval;
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
args.num_splits = num_splits;
args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
}
};
const float appendkv_ave_time = [&] {
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(need_append_kvcache)
{
fmha_fwd_appendkv_traits fwd_appendkv_traits;
init_traits(fwd_appendkv_traits);
fmha_fwd_appendkv_args fwd_appendkv_args;
init_args(fwd_appendkv_args);
return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config);
}
#endif
return 0.0f;
}(); }();
float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_kvcache)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
fmha_fwd_splitkv_args fmha_splitkv_args;
init_args(fmha_splitkv_args);
return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config);
}
#endif
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_fwd_args fmha_args;
init_args(fmha_args);
return fmha_fwd(fmha_traits, fmha_args, stream_config);
}();
if(ave_time < 0) if(appendkv_ave_time < 0.0f || fwd_ave_time < 0.0f)
{ {
std::cout << ", not supported yet" << std::flush << std::endl; std::cout << ", not supported yet" << std::flush << std::endl;
return false; return false;
} }
const float ave_time = (appendkv_ave_time + fwd_ave_time);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
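The 1.E9 and 1.E6 divisors follow from ave_time being reported in milliseconds (as the stream_config timing in this example implies):

\text{TFLOPS} = \frac{\mathrm{flop}}{10^{12}\cdot(\mathrm{ave\_time}/10^{3})} = \frac{\mathrm{flop}}{10^{9}\cdot \mathrm{ave\_time}},
\qquad
\mathrm{GB/s} = \frac{\mathrm{num\_byte}}{10^{9}\cdot(\mathrm{ave\_time}/10^{3})} = \frac{\mathrm{num\_byte}}{10^{6}\cdot \mathrm{ave\_time}}.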
...@@ -775,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -775,36 +1089,46 @@ bool run(const ck_tile::ArgParser& arg_parser)
o_buf.FromDevice(o_host.data()); o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data()); lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data()); randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
return ck_tile::identity{};
}();
float p_undrop = 1.0 - p_drop; float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t = uint8_t p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max())); uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
float rp_undrop = 1.0 / p_undrop; float rp_undrop = 1.0 / p_undrop;
bool pass = true; bool pass = true;
for(ck_tile::index_t wb = 0; wb < batch; ++wb) for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{ {
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode // adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t cache_b_idx =
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = const ck_tile::index_t key_offset =
(mode == mode_enum::batch (mode == mode_enum::batch
? 0 ? 0
: (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb])); : (seqlen_kpads[0] < 0 ? seqstart_k_host[wb] : seqstart_k_with_padding_host[wb]));
const auto v_host_ref_lengths =
std::array<ck_tile::index_t, 3>{nhead, hdim_v, real_seqlen_k};
const auto v_host_ref_strides =
is_v_rowmajor
? std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, 1, hdim_v}
: std::array<ck_tile::index_t, 3>{hdim_v * real_seqlen_k, real_seqlen_k, 1};
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q});
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q});
ck_tile::HostTensor<VDataType> v_host_ref(v_host_ref_lengths, v_host_ref_strides); ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k});
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v});
ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); ck_tile::HostTensor<SMPLComputeDataType> s_host_ref({nhead, real_seqlen_q, real_seqlen_k});
...@@ -815,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -815,22 +1139,138 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format off // clang-format off
// permute // permute
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b_idx, i[1] + query_offset, i[0], i[2]); });
#if CK_TILE_FMHA_FWD_APPENDKV_API
// optionally apply RoPE to the q_host_ref
if(0 < rotary_dim)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding(
q_host_ref, rotary_cos_slice, rotary_sin_slice, is_rotary_interleaved, q_host_ref_ro,
/*use_1_row_sin_cos=*/mask.type == mask_enum::no_mask);
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) {
if(i_perm) {
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
});
} else {
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
});
}
} else
#endif
{
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
}
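The paged path above gathers each logical key position through the block table: the logical block index selects a physical block, and the remainder is the offset inside it. A standalone sketch of that mapping (names here are illustrative, not the library's API):

#include <cassert>
#include <vector>

// Map a logical token position 's' of batch 'b' to physical cache storage,
// given a row-major block table of shape [batch, blocks_per_batch].
struct PagedKvIndex { int physical_block; int offset; };

PagedKvIndex paged_kv_lookup(const std::vector<int>& block_table, int blocks_per_batch,
                             int b, int s, int page_block_size)
{
    const int logical_block = s / page_block_size;
    assert(logical_block < blocks_per_batch);
    return {block_table[b * blocks_per_batch + logical_block], s % page_block_size};
}

For example, with page_block_size=128, token s=300 of batch 0 resolves to offset 44 inside whichever physical block entry 2 of block_table row 0 names.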
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Knew to the end of K
if(0 < seqlen_knew)
{
ck_tile::HostTensor<KDataType> knew_host_ref({nhead, seqlen_knew, hdim_q});
if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[0] / nr, i[1], i[2]); });
else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(wb, i[1], i[0] / nr, i[2]); });
// optionally apply RoPE to the knew_host_ref
auto* real_knew_host_ref = &knew_host_ref;
std::optional<decltype(knew_host_ref)> knew_host_ref_ro;
if(0 < rotary_dim)
{
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); ck_tile::reference_batched_rotary_position_embedding(
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); knew_host_ref,
rotary_cos_slice,
rotary_sin_slice,
is_rotary_interleaved,
knew_host_ref_ro.value());
real_knew_host_ref = &knew_host_ref_ro.value();
}
if (is_v_rowmajor) { (*real_knew_host_ref).ForEach([&](auto& self, auto i) {
k_host_ref(i[0], i[1] + cache_seqlen_ks[wb], i[2]) = self(i);
});
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size) {
if(is_v_rowmajor) {
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
});
} else {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
});
}
}
else
{
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
});
} else {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[1], i[0] / nr, i[2] % page_block_size);
});
}
}
} else
#endif
{
if(is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
} }
else { else
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); {
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
} }
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Vnew to the end of V
if(0 < seqlen_knew)
{
ck_tile::HostTensor<VDataType> vnew_host_ref({nhead, hdim_v, seqlen_knew});
if(is_v_rowmajor)
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[2], i[1]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[2], i[0] / nr, i[1]); });
}
else
{
if(i_perm) vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[0] / nr, i[1], i[2]); });
else vnew_host_ref.ForEach([&](auto& self, auto i) { self(i) = vnew_host(wb, i[1], i[0] / nr, i[2]); });
}
vnew_host_ref.ForEach([&](auto& self, auto i) {
v_host_ref(i[0], i[1], i[2] + cache_seqlen_ks[wb]) = self(i);
});
}
#endif
// clang-format on // clang-format on
// reference // reference
...@@ -959,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -959,7 +1399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref( ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); {nhead, real_seqlen_q, real_seqlen_k});
randval_host_ref.ForEach([&](auto& self, auto idx) { randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); self(idx) = randval_host(b_idx, idx[0], idx[1] + query_offset, idx[2]);
}); });
ck_tile::reference_batched_dropout( ck_tile::reference_batched_dropout(
p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); p_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
...@@ -976,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -976,8 +1416,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v}); ck_tile::HostTensor<ODataType> o_host_result({nhead, real_seqlen_q, hdim_v});
// clang-format off // clang-format off
// permute // permute
if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[0], idx[1] + query_offset, idx[2]); });
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method); auto [rtol, atol] = get_elimit<DataType>(init_method);
...@@ -999,7 +1439,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -999,7 +1439,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q}); ck_tile::HostTensor<SMPLComputeDataType> lse_host_result({nhead, real_seqlen_q});
lse_host_result.ForEach([&](auto& self, auto idx) { lse_host_result.ForEach([&](auto& self, auto idx) {
self(idx) = lse_host(b, idx[0], idx[1] + query_offset); self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset);
}); });
cur_pass = ck_tile::check_err(lse_host_result, cur_pass = ck_tile::check_err(lse_host_result,
......
...@@ -5,10 +5,13 @@ ...@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp" #include "ck_tile/ops/fmha.hpp"
#include "bias.hpp" #include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits> #include <type_traits>
template <typename DataType> template <typename DataType>
...@@ -93,13 +96,86 @@ struct fmha_fwd_args ...@@ -93,13 +96,86 @@ struct fmha_fwd_args
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr; void* rand_val_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float scale_p;
float scale_o;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
struct fmha_fwd_splitkv_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_acc_ptr; void* lse_acc_ptr;
void* o_acc_ptr; void* o_acc_ptr;
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// kvcache mode (use same kernel as batch mode):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
const void* seqstart_k_ptr; const void* seqstart_k_ptr;
const void* seqlen_k_ptr; const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k; ck_tile::index_t seqlen_k;
ck_tile::index_t batch; ck_tile::index_t batch;
...@@ -109,21 +185,21 @@ struct fmha_fwd_args ...@@ -109,21 +185,21 @@ struct fmha_fwd_args
ck_tile::index_t nhead_q; ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k; ck_tile::index_t nhead_k;
ck_tile::index_t num_splits; ck_tile::index_t num_splits;
float scale_s; float scale_s;
float scale_p; float scale_p;
float scale_o; float scale_o;
ck_tile::index_t stride_q; ck_tile::index_t stride_q;
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o_acc; ck_tile::index_t stride_o_acc;
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
...@@ -132,19 +208,62 @@ struct fmha_fwd_args ...@@ -132,19 +208,62 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_lse_acc; ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc; ck_tile::index_t split_stride_o_acc;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop; };
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset; struct fmha_fwd_appendkv_args
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_knew;
ck_tile::index_t batch;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
ck_tile::index_t rotary_dim;
bool has_mask;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
const void* cache_batch_idx;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
}; };
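The comment in fmha_fwd_splitkv_args spells out how the kernels resolve per-batch sequence lengths; a hedged sketch of those rules as plain host code (written from the comment, not taken from the kernel source):

#include <cstdint>

// Group mode reads cumulative 'seqstart' arrays; batch mode uses the scalars;
// kvcache mode runs the batch-mode kernel but derives seqlen_k per batch.
struct effective_lens { int32_t seqlen_q; int32_t seqlen_k; };

effective_lens resolve(bool is_group_mode, bool is_kvcache, int32_t b,
                       int32_t seqlen_q, int32_t seqlen_k,
                       const int32_t* seqstart_q_ptr, const int32_t* seqstart_k_ptr)
{
    if(is_group_mode)
        return {seqstart_q_ptr[b + 1] - seqstart_q_ptr[b],
                seqstart_k_ptr[b + 1] - seqstart_k_ptr[b]};
    if(is_kvcache)
        return {seqlen_q, seqstart_k_ptr[b + 1] - seqstart_k_ptr[b]};
    return {seqlen_q, seqlen_k};
}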
template <typename FmhaKernel> template <typename FmhaKernel>
...@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
} }
template <typename Kernel> template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{ {
assert(args.nhead_q % args.nhead_k == 0); assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] { auto kargs = [&] {
...@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr, args.lse_acc_ptr,
args.o_acc_ptr, args.o_acc_ptr,
args.batch, args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr, args.seqstart_q_ptr,
args.seqstart_k_ptr, args.seqstart_k_ptr,
args.seqlen_k_ptr, args.seqlen_k_ptr,
...@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o_acc, args.stride_o_acc,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc, args.nhead_stride_lse_acc,
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc, args.split_stride_o_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type);
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
...@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_acc_ptr, args.lse_acc_ptr,
args.o_acc_ptr, args.o_acc_ptr,
args.batch, args.batch,
args.max_seqlen_q,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q, args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.num_splits, args.num_splits,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
args.stride_q, args.stride_q,
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o_acc, args.stride_o_acc,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse_acc, args.nhead_stride_lse_acc,
args.nhead_stride_o_acc, args.nhead_stride_o_acc,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse_acc, args.batch_stride_lse_acc,
args.batch_stride_o_acc, args.batch_stride_o_acc,
args.split_stride_lse_acc, args.split_stride_lse_acc,
args.split_stride_o_acc, args.split_stride_o_acc,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type, args.mask_type);
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
}(); }();
...@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args) ...@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
} }
template <typename Kernel> template <typename Kernel>
auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{ {
assert(args.nhead_q % args.nhead_k == 0); assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] { auto kargs = [&] {
...@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args) ...@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
return ck_tile::make_tuple(kargs, grids); return ck_tile::make_tuple(kargs, grids);
} }
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.has_mask,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internal kernel implementation, not to instantiate kernel // this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
...@@ -458,8 +615,52 @@ struct fmha_fwd_traits_ ...@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
template <typename Traits_> template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_> template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_(); std::string fmha_fwd_splitkv_get_name_();
...@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_ ...@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
}; };
template <typename Traits_> template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_(); std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_fwd_traits struct fmha_fwd_traits
{ {
...@@ -508,4 +743,32 @@ struct fmha_fwd_traits ...@@ -508,4 +743,32 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
struct fmha_fwd_splitkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
fmha_fwd_splitkv_args,
const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
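// A minimal dispatch sketch for the appendkv API (assumptions: the
// fmha_fwd_appendkv_args instance is populated by the host/example code, and a
// default-constructed ck_tile::stream_config targets the default stream):
inline float run_appendkv_fp16_example(const fmha_fwd_appendkv_args& args)
{
    // field order follows the struct above: hdim_q, hdim_v, data_type,
    // is_v_rowmajor, rope_type
    fmha_fwd_appendkv_traits traits{128, 128, "fp16", true, rope_enum::none};
    return fmha_fwd_appendkv(traits, args, ck_tile::stream_config{});
}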
...@@ -5,25 +5,30 @@ ...@@ -5,25 +5,30 @@
import argparse import argparse
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
import pkgutil
import sys
from typing import List, Optional from typing import List, Optional
import codegen.ops
from codegen.cmake_config import * from codegen.cmake_config import *
from codegen.ops import (
fmha_fwd,
fmha_fwd_splitkv,
fmha_bwd
)
class HandlerId(IntEnum): class HandlerId(IntEnum):
LIST_BLOBS = 0 LIST_BLOBS = 0
WRITE_BLOBS = 1 WRITE_BLOBS = 1
handlers = { # inspect all modules under 'codegen.ops' and register API handlers
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), ops = []
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
} if full_module_name not in sys.modules:
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None: if output_dir is None:
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep in sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
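// Usage sketch (ck_tile::half_t is an assumed data type; the seed is pinned so
// reference checks stay reproducible). The returned tensors have shape
// {2 * seqlen, rotary_dim / 2}.
inline auto make_example_rope_tables()
{
    return generate_rotary_cos_sin<ck_tile::half_t>(/*seqlen=*/256,
                                                    /*rotary_dim=*/64,
                                                    /*seed=*/11939u);
}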
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}
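// Continuing the sketch above: slice out the rows for positions [16, 16 + 128)
// of a cached sequence (valid because 16 + 128 <= 2 * 256 rows).
inline auto slice_example_rope_tables(const ck_tile::HostTensor<ck_tile::half_t>& cos,
                                      const ck_tile::HostTensor<ck_tile::half_t>& sin)
{
    return slice_rotary_cos_sin(cos, sin, /*seqlen_offset=*/16, /*seqlen=*/128);
}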
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
......
#!/bin/sh #!/bin/bash
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
...@@ -10,44 +9,98 @@ export CK_REPEAT=1 ...@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1' COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0 # mode=0
# export HIP_VISIBLE_DEVICES=4 # export HIP_VISIBLE_DEVICES=4
set -x
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done TEST_SPLITKV=0
done TEST_APPENDKV=0
done # options:
done # -s: run splitkv tests
done # -a: run appendkv tests
done while getopts ":sa" opt; do
done case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done done
run_fp16_bf16_tests() {
local NUM_SPLITS=(1)
local PAGE_BLOCK_SIZE=(0)
local CACHE_BATCH_IDX=(0)
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS+=(2 3)
PAGE_BLOCK_SIZE+=(128)
CACHE_BATCH_IDX+=(1)
fi
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done
done
done
done
set +x set +x
\ No newline at end of file
...@@ -3,15 +3,17 @@ ...@@ -3,15 +3,17 @@
#pragma once #pragma once
#include <algorithm>
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <functional>
#include <optional> #include <optional>
#include <ostream> #include <ostream>
#include <sstream>
#include <string>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp" #include "ck_tile/core/container/span.hpp"
...@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens) ...@@ -40,13 +42,17 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
std::vector<int32_t> generate_seqlens(mode_enum mode, std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count, unsigned count,
int32_t seqlen_avg, int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
assert(0 < count); assert(0 < count);
std::vector<int32_t> seqlens( seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(mode == mode_enum::group && 1 < count) if(mode == mode_enum::group && 1 < count)
{ {
...@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -62,15 +68,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{ {
const size_type to_decrease = next_idx(); const size_type to_decrease = next_idx();
// make sure each element of seqlens is always greater than 0 // make sure each element of seqlens is in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == 1) if(seqlens[to_decrease] == seqlen_min)
{ {
continue; continue;
} }
const size_type to_increase = (to_decrease + next_step()) % count; const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) if(seqlens[to_increase] >= seqlen_max)
{ {
continue; continue;
} }
...@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -86,10 +92,36 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
std::vector<int32_t> generate_seqstarts(mode_enum mode, std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count, unsigned count,
int32_t seqlen_avg, int32_t seqlen_avg,
int32_t seqlen_min = -1,
int32_t seqlen_max = -1, int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
}
// fill [first, last) with random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(engine); });
} }
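// Usage sketch for the two helpers above (values are illustrative): draw one
// integer with randint, then fill a whole vector with randints.
inline std::vector<int32_t> make_example_seqlen_ks(std::optional<unsigned> seed = 0u)
{
    const int batch = randint(1, 8, seed); // random batch size in [1, 8]
    std::vector<int32_t> seqlen_ks(batch);
    // per-batch seqlen_k values drawn uniformly from [1, 256]
    randints(seqlen_ks.begin(), seqlen_ks.end(), int32_t{1}, int32_t{256}, seed);
    return seqlen_ks;
}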
/* /*
...@@ -112,6 +144,8 @@ decode_seqlen(mode_enum mode, ...@@ -112,6 +144,8 @@ decode_seqlen(mode_enum mode,
std::string q_val, std::string q_val,
std::string k_val, std::string k_val,
std::string k_pad_val, std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool use_kvcache = false,
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str())) #define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
...@@ -119,9 +153,36 @@ decode_seqlen(mode_enum mode, ...@@ -119,9 +153,36 @@ decode_seqlen(mode_enum mode,
{ {
ck_tile::index_t q = _S2I_(q_val); ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val); ck_tile::index_t k = _S2I_(k_val);
auto s_q = std::vector<ck_tile::index_t>(batch, q); auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k); auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && use_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in the first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch mode does not support k_padding auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch mode does not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad); return std::make_tuple(s_q, s_k, s_kpad);
} }
else else
...@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode, ...@@ -149,6 +210,16 @@ decode_seqlen(mode_enum mode,
s_q.push_back(q); s_q.push_back(q);
s_k.push_back(k < 0 ? q : k); s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp); s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++; idx++;
if(found_q == std::string::npos || idx >= batch) if(found_q == std::string::npos || idx >= batch)
{ {
...@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode, ...@@ -160,8 +231,9 @@ decode_seqlen(mode_enum mode,
} }
if(idx < batch) if(idx < batch)
{ {
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); auto rem_k =
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
...@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int) ...@@ -180,3 +252,15 @@ int env_get_int(const char* var_name, int default_int)
r = std::atoi(v); r = std::atoi(v);
return r; return r;
} }
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
{
std::iota(first, last, value);
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
}
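// Usage sketch: a random permutation of block indices 0..15, e.g. for building
// a shuffled block table when exercising the paged-KV cache paths.
inline std::vector<int32_t> make_example_block_table(std::optional<unsigned> seed = 0u)
{
    std::vector<int32_t> block_ids(16);
    iota_shuffle(block_ids.begin(), block_ids.end(), int32_t{0}, seed);
    return block_ids;
}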
...@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -1039,14 +1039,14 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
return false; return false;
if constexpr(!((NDimSpatial == 1 && if constexpr(!((NDimSpatial == 1 &&
(is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() || (is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())) || is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 2 && (NDimSpatial == 2 &&
(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() || (is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) || is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>())) ||
(NDimSpatial == 3 && (NDimSpatial == 3 &&
(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() || (is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>())))) is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))))
{ {
return false; return false;
} }
......
...@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle ...@@ -864,23 +864,23 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
} }
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>()) if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{ {
return false; return false;
} }
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>())) is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp> #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial, ...@@ -191,7 +192,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t NumGroupsToMerge = 1, index_t NumGroupsToMerge = 1,
typename ComputeTypeA = InDataType, typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA> typename ComputeTypeB = ComputeTypeA,
index_t TransposeTransferSrcScalarPerVector = 1,
index_t TransposeTransferDstScalarPerVector = 1>
struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
: public DeviceGroupedConvBwdWeight<NDimSpatial, : public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout, InLayout,
...@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -216,6 +219,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using BDataType = InDataType; using BDataType = InDataType;
using EDataType = WeiDataType; using EDataType = WeiDataType;
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||
is_same_v<ADataType, BDataType>);
using AElementwiseOperation = OutElementwiseOperation; using AElementwiseOperation = OutElementwiseOperation;
using BElementwiseOperation = InElementwiseOperation; using BElementwiseOperation = InElementwiseOperation;
using CDEElementwiseOperation = WeiElementwiseOperation; using CDEElementwiseOperation = WeiElementwiseOperation;
...@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -351,6 +359,142 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch)[I2]; batch)[I2];
} }
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& HiStride = g_n_c_wis_strides[3];
const index_t& WiStride = g_n_c_wis_strides[4];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Hi = g_n_c_wis_lengths[3];
const index_t& Wi = g_n_c_wis_lengths[4];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeInputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& GStride = g_n_c_wis_strides[0];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t& CStride = g_n_c_wis_strides[2];
const index_t& DiStride = g_n_c_wis_strides[3];
const index_t& HiStride = g_n_c_wis_strides[4];
const index_t& WiStride = g_n_c_wis_strides[5];
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto MakeOutputTransposeDesc(std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_lengths,
std::array<ck::index_t, NDimSpatial + 3> g_n_c_wis_strides)
{
const index_t& G = g_n_c_wis_lengths[0];
const index_t& N = g_n_c_wis_lengths[1];
const index_t& C = g_n_c_wis_lengths[2];
const index_t& Di = g_n_c_wis_lengths[3];
const index_t& Hi = g_n_c_wis_lengths[4];
const index_t& Wi = g_n_c_wis_lengths[5];
const index_t& NStride = g_n_c_wis_strides[1];
const index_t DiStride = Hi * Wi * G * C;
const index_t HiStride = Wi * G * C;
const index_t WiStride = G * C;
const index_t GStride = C;
const index_t CStride = 1;
const auto desc = make_naive_tensor_descriptor(
make_tuple(N, G, C, Di, Hi, Wi),
make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride));
const auto merged_desc =
transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(make_tuple(N, G, C)),
make_merge_transform(make_tuple(Di, Hi, Wi))),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return PadTensorDescriptor(
merged_desc,
make_tuple(MPerBlock / ClusterLengthMPerBlock, NPerBlock / ClusterLengthNPerBlock),
Sequence<true, true>{});
}
using InputTransposeDescType =
remove_cvref_t<decltype(MakeInputTransposeDesc<NDimSpatial>({}, {}))>;
using OutputTransposeDescType =
remove_cvref_t<decltype(MakeOutputTransposeDesc<NDimSpatial>({}, {}))>;
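// Worked example for the two descriptors above (assumed shapes): with N=2, G=4,
// C=8, Hi=Wi=16 in NGCHW, the input descriptor takes the user strides
//   NStride=8192, GStride=2048, CStride=256, HiStride=16, WiStride=1,
// while the output descriptor recomputes packed strides for the transposed layout:
//   HiStride=Wi*G*C=512, WiStride=G*C=32, GStride=C=8, CStride=1.
// Both are then merged into a 2-D (N*G*C, Hi*Wi) view and padded up to multiples
// of (MPerBlock/ClusterLengthMPerBlock, NPerBlock/ClusterLengthNPerBlock).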
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>()); using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>; using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
...@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -407,13 +551,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ComputeTypeA, ComputeTypeA,
ComputeTypeB>; ComputeTypeB>;
static constexpr index_t ClusterLengthMPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
static constexpr index_t ClusterLengthNPerBlock =
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;
using GridwiseElementwise = using GridwiseElementwiseCast =
GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>, GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
Tuple<CElementwiseGridDesc_M_N>, Tuple<CElementwiseGridDesc_M_N>,
Tuple<const AccDataType*>, Tuple<const AccDataType*>,
...@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -431,6 +571,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1, I1,
I1>; I1>;
using GridwiseElementwiseTranspose =
GridwiseElementwise<Tuple<InputTransposeDescType>,
Tuple<OutputTransposeDescType>,
Tuple<const ADataType*>,
Tuple<ADataType*>,
Block2TileMapElementwise,
element_wise::PassThrough,
BlockSize,
MPerBlock,
NPerBlock,
MPerBlock / ClusterLengthMPerBlock,
NPerBlock / ClusterLengthNPerBlock,
Sequence<1, 0>,
Sequence<TransposeTransferSrcScalarPerVector>,
Sequence<TransposeTransferDstScalarPerVector>,
I1,
I0>;
// Argument // Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -493,6 +651,45 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
end(a_g_n_k_wos_lengths), end(a_g_n_k_wos_lengths),
begin(output_spatial_lengths_)); begin(output_spatial_lengths_));
std::array<index_t, NDimSpatial + 3> b_g_n_c_wis_strides_transposed =
b_g_n_c_wis_strides;
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
a_g_n_k_wos_strides;
// NGCHW/NGCDHW layouts - transpose needed
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
b_g_n_c_wis_strides_transposed[I0] = Conv_C_;
b_g_n_c_wis_strides_transposed[I2] = I1;
a_g_n_k_wos_strides_transposed[I0] = Conv_K_;
a_g_n_k_wos_strides_transposed[I2] = I1;
if constexpr(NDimSpatial == 2)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] =
output_spatial_lengths_[I1] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I4] = Conv_G_ * Conv_K_;
}
else if constexpr(NDimSpatial == 3)
{
b_g_n_c_wis_strides_transposed[I3] =
input_spatial_lengths_[I1] * input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_C_;
b_g_n_c_wis_strides_transposed[I5] = Conv_G_ * Conv_C_;
a_g_n_k_wos_strides_transposed[I3] = output_spatial_lengths_[I1] *
input_spatial_lengths_[I2] * Conv_G_ *
Conv_K_;
a_g_n_k_wos_strides_transposed[I4] =
input_spatial_lengths_[I2] * Conv_G_ * Conv_K_;
a_g_n_k_wos_strides_transposed[I5] = Conv_G_ * Conv_K_;
}
}
const auto descs = const auto descs =
conv_to_gemm_transformer_v2 conv_to_gemm_transformer_v2
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>( .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
...@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -502,9 +699,9 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_spatial_lengths_, input_spatial_lengths_,
filter_spatial_lengths_, filter_spatial_lengths_,
output_spatial_lengths_, output_spatial_lengths_,
b_g_n_c_wis_strides, b_g_n_c_wis_strides_transposed,
e_g_k_c_xs_strides, e_g_k_c_xs_strides,
a_g_n_k_wos_strides, a_g_n_k_wos_strides_transposed,
conv_filter_strides, conv_filter_strides,
conv_filter_dilations, conv_filter_dilations,
input_left_pads, input_left_pads,
...@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -540,8 +737,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
// A/B/C Batch Stride // A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0];
compute_ptr_offset_of_batch_.BatchStrideC_ = compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ * Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_), std::accumulate(begin(filter_spatial_lengths_),
...@@ -553,13 +750,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -553,13 +750,58 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
ce_grid_desc_m_n_, ce_grid_desc_m_n_,
GridwiseGemm::CalculateMBlock(GemmM), GridwiseGemm::CalculateMBlock(GemmM),
GridwiseGemm::CalculateNBlock(GemmN)); GridwiseGemm::CalculateNBlock(GemmN));
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
a_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
a_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(a_g_n_k_wos_lengths, a_g_n_k_wos_strides);
b_in_transpose_desc_ =
MakeInputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
b_out_transpose_desc_ =
MakeOutputTransposeDesc<NDimSpatial>(b_g_n_c_wis_lengths, b_g_n_c_wis_strides);
elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{
a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)};
elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{
b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)};
}
} }
std::size_t GetWorkspaceSizeBytes() const std::size_t GetWorkspaceATensorSizeBytes() const
{
return sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceBTensorSizeBytes() const
{
return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize();
}
std::size_t GetWorkspaceETensorSizeBytes() const
{ {
return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_;
} }
std::size_t GetWorkspaceSizeBytes() const
{
// Transpose requires workspace for A and B
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() +
GetWorkspaceETensorSizeBytes();
}
else
{
return GetWorkspaceETensorSizeBytes();
}
}
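// Workspace layout sketch when the NGCHW/NGCDHW transpose path is active (byte
// offsets; this mirrors the pointer offset arithmetic used when launching the
// kernels below):
//   [0,     E)         : AccDataType GEMM accumulation tensor
//   [E,     E + A)     : transposed A (output gradient)
//   [E + A, E + A + B) : transposed B (input)
// where E/A/B are GetWorkspaceETensorSizeBytes()/GetWorkspaceATensorSizeBytes()/
// GetWorkspaceBTensorSizeBytes().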
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
...@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -571,6 +813,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
Block2TileMapElementwise elementwise_block_2_ctile_map_; Block2TileMapElementwise elementwise_block_2_ctile_map_;
Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_,
elementwise_block_2_ctile_map_transpose_b_;
InputTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_;
OutputTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
...@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -624,17 +871,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_); AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
const ADataType* p_a_grid = arg.p_a_grid_;
const BDataType* p_b_grid = arg.p_b_grid_;
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
p_a_grid = type_convert<const ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
p_b_grid =
type_convert<const BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
}
// nullptr for output; it will be set after the workspace is assigned // nullptr for output; it will be set after the workspace is assigned
typename GridwiseGemm::Argument gemm_arg{arg.p_a_grid_, typename GridwiseGemm::Argument gemm_arg{
arg.p_b_grid_, p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
p_c_grid,
GemmM,
GemmN,
GemmK,
I0,
I0,
I0,
arg.k_batch_};
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
...@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -651,8 +904,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() { const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync( hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); 0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
}; };
const auto Run = [&](const auto& kernel) { const auto Run = [&](const auto& kernel) {
...@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1261,6 +1516,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float avg_time = 0.f;
auto launch_elementwise_kernel = [&]() { auto launch_elementwise_kernel = [&]() {
const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_); const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize(
...@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1270,7 +1526,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
std::array<index_t, I1> in_out_batch_strides = { std::array<index_t, I1> in_out_batch_strides = {
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise, const auto kernel = kernel_batched_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc_M_N>, ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<CElementwiseGridDesc_M_N>, ck::Tuple<CElementwiseGridDesc_M_N>,
ck::Tuple<const AccDataType*>, ck::Tuple<const AccDataType*>,
...@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1296,7 +1552,54 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
in_out_batch_strides); in_out_batch_strides);
}; };
float avg_time = RunGemmV3(arg, stream_config); if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
const index_t grid_size_a =
arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize(
arg.a_in_transpose_desc_);
const index_t grid_size_b =
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize(
arg.b_in_transpose_desc_);
ADataType* p_a_out_grid = type_convert<ADataType*>(arg.p_workspace_) +
arg.GetWorkspaceETensorSizeBytes() / sizeof(BDataType);
BDataType* p_b_out_grid =
type_convert<BDataType*>(arg.p_workspace_) +
(arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) /
sizeof(BDataType);
auto kernel_transpose = kernel_elementwise_dual<GridwiseElementwiseTranspose,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<InputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<OutputTransposeDescType>,
ck::Tuple<const ADataType*>,
ck::Tuple<BDataType*>,
Block2TileMapElementwise,
Block2TileMapElementwise,
element_wise::PassThrough>;
avg_time += launch_and_time_kernel(stream_config,
kernel_transpose,
dim3(grid_size_a + grid_size_b),
dim3(BlockSize),
0,
make_tuple(arg.a_in_transpose_desc_),
make_tuple(arg.b_in_transpose_desc_),
make_tuple(arg.a_out_transpose_desc_),
make_tuple(arg.b_out_transpose_desc_),
make_tuple(arg.p_a_grid_),
make_tuple(arg.p_b_grid_),
make_tuple(p_a_out_grid),
make_tuple(p_b_out_grid),
arg.elementwise_block_2_ctile_map_transpose_a_,
arg.elementwise_block_2_ctile_map_transpose_b_,
element_wise::PassThrough{},
grid_size_a);
}
avg_time += RunGemmV3(arg, stream_config);
avg_time += launch_elementwise_kernel(); avg_time += launch_elementwise_kernel();
return avg_time; return avg_time;
} }
...@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1347,25 +1650,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
{ {
return false; return false;
} }
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 2)
{ {
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>()) if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
{ is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>()))
return false;
}
}
else if constexpr(NDimSpatial == 2)
{
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>())) is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
...@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1431,6 +1727,35 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
return false; return false;
} }
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0)
{
return false;
}
const index_t input_spatial_acum = ck::accumulate_n<index_t>(
arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t output_spatial_acum = ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0)
{
return false;
}
}
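// Illustration (assumed numbers): with G = 2, C = 32, K = 64 and 28x28 spatial
// extents, a TransposeTransferDstScalarPerVector of 4 divides both G * C = 64
// and G * K = 128, and a TransposeTransferSrcScalarPerVector of 4 divides
// 28 * 28 = 784, so the transpose path is admitted.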
return true; return true;
} }
...@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle ...@@ -1563,8 +1888,17 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: " << "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge << NumGroupsToMerge;
<< ">";
if constexpr(is_NGCHW_GKYXC_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKZYXC_NGKDHW<InLayout, WeiLayout, OutLayout>()) {
str << ", TransposeTransferSrcScalarPerVector: "
<< TransposeTransferSrcScalarPerVector <<", "
<< "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector;
}
str << ">";
// clang-format on // clang-format on
return str.str(); return str.str();
......
...@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ...@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
return false; return false;
} }
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>())) is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
......
...@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
} }
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>()) if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
{ {
return false; return false;
} }
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>())) is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() || if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>())) is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
{ {
return false; return false;
} }
......