"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "3141ea440a843a44fc34ed06a93a29d283727583"
Unverified Commit c1569892 authored by Po Yen Chen's avatar Po Yen Chen Committed by GitHub
Browse files

[CK_TILE] Add PagedAttention kernels (#1387)



* Use dictionary to config all the functions

* Add init codegen logic for fmha fwd appendkv

* Call HIP_CHECK_ERROR() macro to get real source info

* Set up meaningful arguments

* Sync kernel name with the codegen

* Add knew/vnew tensors to the kernel argument

* Fix wrong K values after appending

* Fix vnew append error

* Extract common logic

* Fix Vnew tile dstr for row major case

* Conditionally add fwd_splitkv API in fmha_fwd example

* Conditionally add call to fmha_fwd_splitkv()

* Remove "EXAMPLE_" prefix of cmake variables

* Register API handlers automatically

* Early return if 0 < s_k_new is not supported

* Show message if we are ignoring option

* Unify CMakeLists.txt coding style

* Set num_splits=1 if split-kv is not supported

* Add length/stride getters for HostTensor

* Add RoPE example utilities

* Add reference_rotary_position_embedding() (not implemented)

* Finish reference_rotary_position_embedding() impl

* Fix typo of HostTensor<>::get_length()

* Fix compilation errors

* Fix wrong answer when interleaved=false

* Fix wrong answer when interleaved=true

* Append K/V in the host verification code

* Simplify K appending logic

* Simplify v_host_ref definition

* Reduce input/output dimensions

* Rename function: add "batched" prefix

* Apply RoPE on host side

* Rename RoPE utility function

* Fix wrong tensor size

* Avoid invoking deprecated method 'find_module'

* Pass RoPE kernel args

* Create Rotary Cos/Sin tile windows in kernel

* Add compute data type alias for RoPE

* Randomly generate seqlen_knew if needed

* Fix seqlen_knew enabling check logic

* Add minimum seqlen_k to generate compliant kvcache

* Fix compilation error in debug mode

* Fix wrong boundaries

* Fix wrong seqlen_k for kvcache

* Rename variables used in distribution encoding

* Fix rotary cos/sin tensor/tile size

* Add constraint to the rotary_dim option

* Remove unused inner namespace

* Add dram distribution for rotary_cos/rotary_sin (interleaved)

* Only apply interleaved RoPE on Knew for now

* Fix wrong thread starting offset

* Instantiate multiple kernels for RoPE approaches

* Clean-up pipeline

* Fix error in RoPE host reference

* Handle RoPE half-rotated logics

* Support 8x rotary_dim under half-rotated RoPE

* Add comment

* Apply elementwise function to the loaded tiles

* Unify parameter/variable naming style

* Remove constness from q_ptr

* Add code blocks for q_tile

* Apply RoPE to q_tile

* Remove debug print code in kernel

* Fix wrong knew/vnew appending positions

* Use better naming for tile indices

* Add make_tile_window() for adding distribution only

* Skip code if # of blocks is more than needed

* Move thread locating logic into policy

* Remove always true static_assert()

* Rename header

* Rename RotaryEmbeddingEnum

* Extract rotary embedding logic out

* Re-order parameters

* Align naming of some tile size constants

* Rename more tile size constants

* Fix wrong grid size

* Fix wrong shape of knew_host/vnew_host

* Fix wrong index into knew_host/vnew_host

* Fix wrong rotary_cos/rotary_sin memory size for Q

* Extract Q/Knew vector size to helper methods

* Use different rotary_cos/rotary_sin distr for Q/Knew

* Update host/device specifiers

* Fix wrong data type for Q rotary_cos/rotary_sin

* Remove RoPEComputeDataType type alias

* Shift rotary_cos/rotary_sin by cache_seqlen_k

* Add comment for why I just use 't' for all padding flags

* Align commit message to the real comment

* Fix wrong pipeline

* Rename utility function

* Disable host verification if API does not exist

* Fix wrong rope key for fp8 pipeline

* Allow applying RoPE only on Q (without appending KV)

* Add append-kv smoke tests

* Remove debug statements

* Remove more debug statements

* Re-arrange the 'set +x' command

* Remove no-longer used method in pipeline

* Add missing init code

* Refine pipeline padding settings

* Enlarge rotary_dim limit (8 -> 16)

* Enlarge KPerThread for rotary_interleaved=false

* Update rotary_dim range in smoke_test_fwd.sh

* Add template argument 'kIsPagedKV' for splitkv kernels

* Launch splitkv kernel if given page_block_size

* Fix wrong kernel name

* Fix seqlen_k_min for pre-fill case (1 -> 0)

* Add copy_const<> type trait

* Add another make_tile_window()

* Introduce 'TileWindowNavigator' types

* Simplify TileWindowNavigator interfaces

* Fix tile window navigation bugs

* Disable calling fmha_fwd()

* Remove unnecessary data members

* Simplify more make_tile_window() overloads

* Move V tile through TileWindowNavigator

* Fix uneven split checking logic

* Move code after deciding seqlen_q/seqlen_k

* Make sure we always start reading a complete tile

* Use 128 as minimum page_block_size

* Fix wrong origin for bias

* Add batch_stride_k/batch_stride_v in group mode

* Unify origin

* Add missing kernel arguments for group mode

* Add paged-kv codegen logic for appendkv kernels

* Add block_table kernel args for appendkv kernel

* Add tile navigators to the appendkv kernel

* Fix wrong tensor descriptor lengths

* Pass re-created tile window to pipeline

* Fix wrong strides for appendkv kernel

* Allow transiting tile_window to another page-block

* Handle cross-page-block write

* Do not perform write again if already in last page-block

* Always add fmha_fwd() api

* Add missing group mode argument

* Remove debug macro usages

* Rename option s_k_new to s_knew

* Separate splitkv/non-splitkv args/traits

* Remove fmha_fwd_dispatch()

* Fix compilation errors

* Remove dropout code in splitkv kernel

* Allow problem types without defining kHasDropout attr

* Use generic lambda to init traits objects

* Separate more non-splitkv & splitkv traits/args

* Display more info for specific kernels

* Show more detailed warning message

* Rename 'max_num_blocks' to 'max_num_page_blocks'

* Remove no-longer used pipeline files

* Wrap code in #if directives

* Move functors to the beginning of validation code

* Use generic lambda to init all the api traits/args

* Fix wrong seqlen for kvcache

* Add missing comment

* Rename TileWindowNavigator to PageBlockNavigator

* Only expose necessary methods (not attributes)

* Re-order pipeline parameters

* Refine smoke_test_fwd.sh

* Fix wrong argument count

* Make tile window directly via PageBlockNavigator

* Remove unused template parameter

* Remove group mode from appendkv kernel

* Fix skcheck logic

* Fix wrong syntax in skcheck expr

* Use meaningful options in smoke test

* Remove options

* Fix formatting

* Fix more format

* Re-organize bash functions

* Pass cache_batch_idx to kernels

* Support cache_batch_idx in example

* Fix compilation error

* Add more appendkv test

* Add more case for appendkv

* Fix nonexistent attribute

* Remove 0 < seqlen_knew constraint

* Clarify the case in warning message

* Remove macro checking

* Force batch mode when invoking appendkv & splitkv apis

* Fix mode overriding logic

* Fix wrong parameter name

* Randomize seqlen_k if using kvcache

* Use randomized seqlen_k for kvcache

* Avoid using too small rotary_cos & rotary_sin

* Rename parameter

* Add seqlen_q & seqlen_k rules

* Add comment

* Add more comments

* Fix compilation errors

* Fix typo in comment

* Remove type argument

* Avoid seqlen_k=0 for kvcache

* Revert "Avoid seqlen_k=0 for kvcache"

This reverts commit 21c4df89e416182e8e9bc78e67bd4b98dbb6c88d.

* Fix wrong uneven split checking logic

* Only randomize kvcache seqlen_k if 1 < batch

* Return earlier if split is empty

* Revert "Only randomize kvcache seqlen_k if 1 < batch"

This reverts commit b9a4ab0d7e3c2beecc0fccafd2a13259dd06299c.

* Re-order seqlen_k_start adjustment logic

* Fix compilation errors

* Re-format script

* Find executable from folder automatically

* Fix kvcache seqlen_k generating logic

* Make comment more clear

* Fix wrong knew/vnew appending logic on host

* Add s_barrier to sync threads

* Revert "Add s_barrier to sync threads"

This reverts commit d3f550f30c0a4d9df15c613015d5dff268d6746d.

* Support only using 1 row of rotary_cos/rotary_sin

* Rotate Q in a different way

* Unify tensor view creation logic

* Fix wrong argument

* Add mask to switch how we use the rotary_cos/sin

* Move attr from traits to problem

* Move has_mask to fmha_fwd_appendkv_args

* Support using uint32_t as SAD operand in Alibi<>

* Use sad_u32() in splitkv kernels

* Store tensor views in PageBlockNavigator

* Use stored tensor view to update tile windows

* Enlarge tensor view size

* Remove debug code

* Fix wrong tensor view size

* Wrap tensor view into PageBlockNavigator

* Add DataType member to PageBlockNavigator

* Remove unnecessary member functions

* Refine macro use

* Fix typo

* Add blank line between directives and actual code

* Re-format files

* Remove type in comment

---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
parent 19d22e60
+# validate user-specified fmha_fwd API list
+set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
+set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
+    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
+if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
+  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
+endif()
+foreach(api ${FMHA_FWD_ENABLE_APIS})
+  if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
+    message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
+  endif()
+endforeach()
+# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
+if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
+endif()
+string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
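(In practice the API list is chosen at configure time, e.g. -DFMHA_FWD_ENABLE_APIS="fwd;fwd_splitkv" or "all"; unknown names fail the configure step, and "fwd" is always re-added because the example requires it.)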
 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
+  --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
 )
 execute_process(
@@ -17,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
 add_custom_command(
   OUTPUT ${FMHA_FWD_GEN_BLOBS}
   COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
-  --api fwd,fwd_splitkv --output_dir ${CMAKE_CURRENT_BINARY_DIR}
+  --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
 )
 add_custom_command(
@@ -60,6 +80,20 @@ else()
 endif()
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
+# conditionally enable call to the fwd_splitkv API in fmha_fwd example
+if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
+endif()
+# conditionally enable call to the fwd_appendkv API in fmha_fwd example
+if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
+else()
+  list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
+endif()
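With these macros always defined (as 0 or 1), the fmha_fwd example can guard its optional calls at compile time. A minimal sketch of such a guard — the flag and variable names below are hypothetical, not taken from the example source:

#if CK_TILE_FMHA_FWD_SPLITKV_API
    if(use_splitkv) // hypothetical flag derived from the num_splits/page_block_size options
        return fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream);
#endif
#if CK_TILE_FMHA_FWD_APPENDKV_API
    if(need_appendkv) // hypothetical flag: s_knew > 0 or rotary_dim > 0
        fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream);
#endif
    return fmha_fwd(traits, args, stream); // the always-generated "fwd" path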
 # Allow comparing floating points directly in order to check sentinel values
 list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
 list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
...
@@ -82,6 +82,18 @@ DROPOUT_CHECK_MAP = {
     "dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
 }
+
+ROPE_MAP = {
+    "no"    : "ck_tile::RotaryEmbeddingEnum::NONE",
+    "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
+    "half"  : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
+}
+
+ROPE_CHECK_MAP = {
+    "no"    : "rope_enum::none",
+    "inter" : "rope_enum::interleaved",
+    "half"  : "rope_enum::half_rotated"
+}
 MODE_MAP = {
     "batch" : "false",
     "group" : "true"
@@ -105,4 +117,4 @@ PIPELINE_ENUM_MAP = {
 BOOL_MAP = {
     "t" : "true",
     "f" : "false"
-}
\ No newline at end of file
+}
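ROPE_MAP distinguishes the two rotary-embedding layouts these kernels implement (plus NONE). As a rough host-side illustration of the difference, using standard RoPE conventions rather than this PR's actual reference code:

// Interleaved RoPE rotates adjacent pairs (x[2i], x[2i+1]);
// half-rotated RoPE pairs x[i] with x[i + rotary_dim/2].
// cos_row/sin_row hold one value per frequency index for this position.
void rotate_one_row(float* x, const float* cos_row, const float* sin_row,
                    int rotary_dim, bool interleaved)
{
    const int half = rotary_dim / 2;
    for(int i = 0; i < half; ++i)
    {
        const int i0 = interleaved ? 2 * i : i;
        const int i1 = interleaved ? 2 * i + 1 : i + half;
        const float x0 = x[i0];
        const float x1 = x[i1];
        x[i0] = x0 * cos_row[i] - x1 * sin_row[i];
        x[i1] = x0 * sin_row[i] + x1 * cos_row[i];
    }
}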
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.ops.fmha_fwd import (
    FmhaFwdApiTrait,
    DTYPE_BITS,
    FMHA_FWD_KERNEL_HEADER,
    FMHA_FWD_API_PER_DTYPE,
    FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
                                                              {F_skpad},
                                                              {F_dpad},
                                                              {F_dvpad},
                                                              {F_occupancy}>;

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    {F_bs},
    {F_bsk},
    {F_bd},
    {F_bdv},
    {F_vlayout},
    {F_rope},
    {F_pagedkv},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
    fmha_pipeline_problem_{F_idx}>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
                                   fmha_pipeline_{F_idx}>;

using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
                                                {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;

#include <iostream>

template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"

FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_APPENDKV_API_INNER_DISPATCH="""    {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
           ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
           ((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
        using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
        return fmha_fwd_appendkv_<trait_>(s, a);
    }}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    bs      : int  # tile size along q seqlen
    bsk     : int  # tile size along k seqlen
    bd      : int  # tile size along qk gemm unroll
    bdv     : int  # tile size along kv gemm unroll
    vlayout : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str
    rope    : str  # key from ROPE_MAP
    pagedkv : str

    @property
    def name(self) -> str:
        return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
               f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'

    @property
    def scheck(self) -> str:
        if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
        else                : return f'a.seqlen_q % {self.bs} == 0'

    @property
    def skcheck(self) -> str:
        # we do not check all the values in a.seqlen_k_ptr
        return 'true'

    @property
    def dcheck(self) -> str:
        if self.dpad == 't' : return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else                : return f'a.hdim_q % {self.bd} == 0'

    @property
    def dvcheck(self) -> str:
        if self.dvpad == 't' : return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
        else                 : return f'a.hdim_v % {self.bdv} == 0'
@dataclass
class FmhaFwdAppendKVPipeline:
    F_vlayout : str  # row/col
    F_spad    : str  # true/false
    F_skpad   : str  #
    F_dpad    : str  #
    F_dvpad   : str  #
    F_rope    : str  # key from ROPE_MAP
    F_pagedkv : str  # t/f

    @property
    def name(self) -> str:
        def pad_name() -> str:
            n = ''
            if self.F_spad == 't'  : n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't'  : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        if self.F_rope != 'no'   : n += f'_{self.F_rope}'
        if self.F_pagedkv == 't' : n += '_pagedkv'
        return n
class FmhaFwdAppendKVApiPool:
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()
        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes=str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case=str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits=self.pool[dtype][hdim]
                inners=str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
                                   F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
                                   F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                   F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdAppendKVTileSize:
    F_bs        : int  # tile size along q seqlen
    F_bsk       : int  # tile size along k seqlen
    F_bd        : int  # tile size along qk gemm unroll
    F_bdv       : int  # tile size along kv gemm unroll
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
               ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdAppendKVKernel:
    F_idx      : int  # this is not a tunable, but a counter to differentiate symbol
    F_hdim     : int  # hdim
    F_dtype    : str  # data type
    F_tile     : FmhaFwdAppendKVTileSize
    F_pipeline : FmhaFwdAppendKVPipeline
    mask_impl  : str

    @property
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_APPENDKV_KERNEL_BODY.format(
                F_idx       = self.F_idx,
                F_hdim      = self.F_hdim,
                F_dtype     = DTYPE_MAP[self.F_dtype],
                F_bs        = self.F_tile.F_bs,
                F_bsk       = self.F_tile.F_bsk,
                F_bd        = self.F_tile.F_bd,
                F_bdv       = self.F_tile.F_bdv,
                F_vlayout   = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad      = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad     = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad      = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad     = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_rope      = ROPE_MAP[self.F_pipeline.F_rope],
                F_pagedkv   = BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_occupancy = self.F_tile.F_occupancy)

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> FmhaFwdAppendKVApiTrait:
        return FmhaFwdAppendKVApiTrait(
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            bs=self.F_tile.F_bs,
            bsk=self.F_tile.F_bsk,
            bd=self.F_tile.F_bd,
            bdv=self.F_tile.F_bdv,
            vlayout=self.F_pipeline.F_vlayout,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            rope=self.F_pipeline.F_rope,
            pagedkv=self.F_pipeline.F_pagedkv)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            '32'  : FmhaFwdAppendKVTileSize(64, 64,  32,  32, -1),
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdAppendKVTileSize(64, 64,  64,  64, -1),
            '128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
            '256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
        }
    else:
        return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    #       support this in future
    def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            # NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
            #         applying rotary embedding, so I just use 't' in inter/half pipelines
            for vlayout in ['row', 'col']:
                for pagedkv in ["t", "f"]:
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
                    pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
        elif dtype in ['fp8', 'bf8']:
            # rope/paged-kv is not supported
            pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdAppendKVApiPool(mask_impl)

    for dtype in DTYPE_MAP.keys():
        d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
        if d == None:
            continue
        for hdim_str in d.keys():
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                k = FmhaFwdAppendKVKernel(F_idx=0,
                                          F_hdim=hdim,
                                          F_dtype=dtype,
                                          F_tile=tile,
                                          F_pipeline=pipeline,
                                          mask_impl=mask_impl)
                if kernel_filter != None:
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if receipt == 2:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    if not cond:
                        continue
                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
    (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)

def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
    for kernel in kernels:
        write_single_kernel(kernel, output_dir)
    write_fwd_appendkv_api(api_pool, output_dir)

def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
    with file_path.open('a') as f:
        _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
\ No newline at end of file
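Reading the naming scheme off the code above: for hdim 128, fp16, the b64x64x128x128 tile, and a row-major V pipeline with all paddings 't', interleaved RoPE, and paged-KV enabled, the emitted file would be fmha_fwd_appendkv_d128_fp16_b64x64x128x128_vr_psskddv_inter_pagedkv.cpp.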
@@ -21,6 +21,14 @@ from codegen.ops.fmha_fwd import (
 )
+
+DTYPE_BITS = {
+    "fp32": 32,
+    "fp16": 16,
+    "bf16": 16,
+    "fp8" : 8,
+    "bf8" : 8
+}
 FMHA_FWD_SPLITKV_PIPELINE_MAP = {
     "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
     "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync",
@@ -51,8 +59,8 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
                                                      {F_bias},
                                                      false,
                                                      {F_lse},
-                                                     {F_dropout},
                                                      {F_squant},
+                                                     {F_pagedkv},
                                                      kHasUnevenSplits,
                                                      {F_occupancy}>;
@@ -63,7 +71,6 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
-    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
     typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
@@ -86,7 +93,7 @@ using fmha_kernel =
         fmha_pipeline,
         fmha_epilogue>;
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
@@ -97,16 +104,21 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
     }};
 }}
-using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
-    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
+    {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
+    {F_dvpad}>;
 #include <iostream>
 template<>
-void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if constexpr({F_mode} == false) {{ // batch mode
-        if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
+        // we don't check every seqlen_k values for kvcache
+        if (a.seqlen_k_ptr != nullptr) {{
+            kernel_runner<true>::run(s, a);
+        // make sure F_bn0 is divisible by F_bk1
+        }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
             kernel_runner<false>::run(s, a);
         }} else {{
             kernel_runner<true>::run(s, a);
@@ -160,7 +172,7 @@ using fmha_kernel =
         fmha_pipeline,
         fmha_epilogue>;
-static void run(const ck_tile::stream_config& s, fmha_fwd_args a)
+static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     using k_ = fmha_kernel;
     auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
@@ -177,7 +189,7 @@ using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_m
 #include <iostream>
 template<>
-void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
+void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if (a.num_splits <= 16) {{
         kernel_runner<4>::run(s, a);
@@ -203,7 +215,7 @@ FMHA_FWD_SPLITKV_API="""
 #include <iostream>
 template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
-float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
+float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
 {{
     if(s.log_level_ > 0)
         std::cout
@@ -217,22 +229,96 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_args a)
         );
 }}
-float fmha_fwd_splitkv(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
+float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
     float r = -1;
 {F_dispatch}
     return r;
 }}
 """
-FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
-        ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
-        using traits_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
+FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+        ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
+        using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
         using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}/2, {F_bn1}, {F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
         return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
     }}
 """
+@dataclass
+class FmhaFwdSplitKVApiTrait:
+    pipeline_tag : str
+    # sync with fmha_fwd_traits<>, to generate fallback calls
+    hdim    : str
+    dtype   : str # data type
+    mode    : str # value from MODE_MAP
+    bm0     : int # tile size along q seqlen (block size)
+    bn0     : int # tile size along qk seqlen
+    bk0     : int # tile size along qk gemm unroll
+    bn1     : int # tile size along v head_dim
+    bk1     : int # tile size along kv gemm unroll
+    bk0blen : int
+    vlayout : str
+    mask    : str
+    bias    : str #
+    lse     : str #
+    squant  : str #
+    spad    : str
+    skpad   : str
+    dpad    : str
+    dvpad   : str
+    pagedkv : str
+
+    @property
+    def name(self) -> str:
+        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
+            f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
+            f'{self.dvpad}-{self.pagedkv}'
+
+    @property
+    def scheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.spad == 't' : return 'true' # always support
+            else                : return 'true'
+        elif self.pipeline_tag in ['qr']:
+            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else                : return f'a.seqlen_q % {self.bm0} == 0'
+        else: assert False
+
+    @property
+    def skcheck(self) -> str:
+        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
+        if self.pipeline_tag == 'qr_async':
+            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
+            else                 : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
+        elif self.pipeline_tag in ['qr', 'qr_fp8']:
+            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else                 : return f'a.seqlen_k % {self.bn0} == 0'
+        else: assert False
+
+    @property
+    def dcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
+            else               : assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else               : return f'a.hdim_q % {self.bk0blen} == 0'
+        else: assert False
+
+    @property
+    def dvcheck(self) -> str:
+        if self.pipeline_tag == 'qr_async':
+            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
+            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
+            else                : assert False
+        elif self.pipeline_tag in ['qr']:
+            if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
+            else                : return f'a.hdim_v % {self.bk0blen} == 0'
+        else: assert False
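(A note on the qr_async checks above: vec = int((32 * 4) / DTYPE_BITS[dtype]) presumably corresponds to a 128-bit load per thread, so for fp16, vec = 128/16 = 8 and hdim_q/hdim_v must be multiples of 8, while for fp8, vec = 16.)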
 @dataclass
 class FmhaFwdSplitKVPipeline:
     tag : str
@@ -244,8 +330,8 @@ class FmhaFwdSplitKVPipeline:
     F_dvpad : str #
     F_bias : str # true/false
     F_lse : str #
-    F_dropout : str #
     F_squant : str #
+    F_pagedkv : str # t/f
     F_mask : str # value from MASK_MAP
     @property
@@ -267,8 +353,8 @@ class FmhaFwdSplitKVPipeline:
         else:
             if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
             if self.F_lse == 't' : n += '_lse'
-            if self.F_dropout == 't' : n += '_dropout'
             if self.F_squant == 't' : n += '_squant'
+            if self.F_pagedkv == 't' : n += '_pagedkv'
         return n
 @dataclass
@@ -300,7 +386,7 @@ class FmhaFwdSplitKVApiPool:
         self.pool = dict()
         self.mask_impl = mask_impl
-    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
+    def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
         # TODO: do we need to check duplication?
         if trait.dtype not in self.pool.keys():
             self.pool[trait.dtype] = dict()
@@ -322,8 +408,8 @@ class FmhaFwdSplitKVApiPool:
                     inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                    F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                    F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
-                                   F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] ,
-                                   F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
+                                   F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
+                                   F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                    F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                    F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
                                    F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
@@ -383,8 +469,8 @@ class FmhaFwdSplitKVKernel:
             F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
             F_bias = BIAS_MAP[self.F_pipeline.F_bias],
             F_lse = BOOL_MAP[self.F_pipeline.F_lse],
-            F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
             F_squant = BOOL_MAP[self.F_pipeline.F_squant],
+            F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
             F_occupancy = self.F_tile.F_occupancy,
             F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
             F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
@@ -401,8 +487,8 @@ class FmhaFwdSplitKVKernel:
     def filename(self) -> str:
         return self.name + ".cpp"
-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
+    def api_trait(self) -> FmhaFwdSplitKVApiTrait:
+        return FmhaFwdSplitKVApiTrait(
             pipeline_tag=self.F_pipeline.tag,
             hdim=str(self.F_hdim),
             dtype=self.F_dtype,
@@ -417,8 +503,8 @@ class FmhaFwdSplitKVKernel:
             mask=self.F_pipeline.F_mask,
             bias=self.F_pipeline.F_bias,
             lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
             squant=self.F_pipeline.F_squant,
+            pagedkv=self.F_pipeline.F_pagedkv,
             spad=self.F_pipeline.F_spad,
             skpad=self.F_pipeline.F_skpad,
             dpad=self.F_pipeline.F_dpad,
@@ -460,29 +546,6 @@ class FmhaFwdSplitKVCombineKernel:
     def filename(self) -> str:
         return self.name + ".cpp"
-    def api_trait(self) -> FmhaFwdApiTrait:
-        return FmhaFwdApiTrait(
-            pipeline_tag=self.F_pipeline.tag,
-            hdim=str(self.F_hdim),
-            dtype=self.F_dtype,
-            mode=self.F_mode,
-            bm0=self.F_tile.F_bm0,
-            bn0=self.F_tile.F_bn0,
-            bk0=self.F_tile.F_bk0,
-            bn1=self.F_tile.F_bn1,
-            bk1=self.F_tile.F_bk1,
-            bk0blen=self.F_tile.F_bk0blen,
-            vlayout=self.F_pipeline.F_vlayout,
-            mask=self.F_pipeline.F_mask,
-            bias=self.F_pipeline.F_bias,
-            lse=self.F_pipeline.F_lse,
-            dropout=self.F_pipeline.F_dropout,
-            squant=self.F_pipeline.F_squant,
-            spad=self.F_pipeline.F_spad,
-            skpad=self.F_pipeline.F_skpad,
-            dpad=self.F_pipeline.F_dpad,
-            dvpad=self.F_pipeline.F_dvpad)
 # TODO: design a more practical way to do it
 # this is current supported tile size per hdim
 def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
@@ -533,27 +596,27 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
         squant = 't' if dtype == 'fp8' else 'f'
         pipelines = []
         if dtype in ['fp16', 'bf16']:
-            # splitkv kernel donot support dropout
-            for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["f"]):
-                if hdim == 256:
+            for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
+                # TODO: use async pipeline when compiler is more stable
+                if hdim == 256 or hdim in [32, 64, 128]:
                 # if True:
-                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 else:
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
-                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask))
+                    pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
                 if receipt == 1:
-                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
-                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
+                    pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
-            # no need lse/dropout kernels
+            # no need lse/paged-kv kernels
            for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
-                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
+                pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, 'f', mask))
         else:
             assert False
         return pipelines
@@ -574,6 +637,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
                 if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                     # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
                     continue
+                if pipeline.F_pagedkv == 't':
+                    # we only use batch mode kernels to handle (paged-) kvcache problems
+                    continue
                 k = Kernel(F_idx=0,
                            F_hdim=hdim,
                            F_dtype=dtype,
...
This diff is collapsed.
@@ -5,10 +5,13 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
-#include "ck_tile/ops/fmha.hpp"
 #include "ck_tile/ops/epilogue.hpp"
-#include "mask.hpp"
+#include "ck_tile/ops/fmha.hpp"
 #include "bias.hpp"
+#include "mask.hpp"
+#include "rotary.hpp"
 #include <type_traits>
 template <typename DataType>
@@ -93,13 +96,86 @@ struct fmha_fwd_args
     const void* v_ptr;
     const void* bias_ptr; // bias or alibi_slope pointer
     void* rand_val_ptr;
+    void* lse_ptr;
+    void* o_ptr;
+    const void* seqstart_q_ptr;
+    const void* seqstart_k_ptr;
+    const void*
+        seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
+    ck_tile::index_t seqlen_q;
+    ck_tile::index_t seqlen_k;
+    ck_tile::index_t batch;
+    ck_tile::index_t max_seqlen_q;
+    ck_tile::index_t hdim_q;
+    ck_tile::index_t hdim_v;
+    ck_tile::index_t nhead_q;
+    ck_tile::index_t nhead_k;
+    float scale_s;
+    float scale_p;
+    float scale_o;
+    ck_tile::index_t stride_q;
+    ck_tile::index_t stride_k;
+    ck_tile::index_t stride_v;
+    ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
+    ck_tile::index_t stride_randval;
+    ck_tile::index_t stride_o;
+    ck_tile::index_t nhead_stride_q;
+    ck_tile::index_t nhead_stride_k;
+    ck_tile::index_t nhead_stride_v;
+    ck_tile::index_t nhead_stride_bias;
+    ck_tile::index_t nhead_stride_randval;
+    ck_tile::index_t nhead_stride_lse;
+    ck_tile::index_t nhead_stride_o;
+    ck_tile::index_t batch_stride_q;
+    ck_tile::index_t batch_stride_k;
+    ck_tile::index_t batch_stride_v;
+    ck_tile::index_t batch_stride_bias;
+    ck_tile::index_t batch_stride_randval;
+    ck_tile::index_t batch_stride_lse;
+    ck_tile::index_t batch_stride_o;
+    ck_tile::index_t window_size_left;
+    ck_tile::index_t window_size_right;
+    ck_tile::index_t mask_type;
+    float p_drop;
+    bool s_randval;
+    std::tuple<uint64_t, uint64_t> drop_seed_offset;
+};
+
+struct fmha_fwd_splitkv_args
+{
+    const void* q_ptr;
+    const void* k_ptr;
+    const void* v_ptr;
+    const void* bias_ptr; // bias or alibi_slope pointer
     void* lse_acc_ptr;
     void* o_acc_ptr;
     void* lse_ptr;
     void* o_ptr;
+    void* block_table_ptr;
+    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
+    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+    const void* cache_batch_idx;
+    // the real seqlen_q & seqlen_k are decided by following:
+    // batch mode: seqlen_q = kargs.seqlen_q
+    //             seqlen_k = kargs.seqlen_k
+    // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
+    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
+    // kvcache mode (use same kernel as batch mode):
+    //             seqlen_q = kargs.seqlen_q
+    //             seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
     const void* seqstart_q_ptr;
     const void* seqstart_k_ptr;
     const void* seqlen_k_ptr;
     ck_tile::index_t seqlen_q;
     ck_tile::index_t seqlen_k;
     ck_tile::index_t batch;
@@ -109,21 +185,21 @@ struct fmha_fwd_args
     ck_tile::index_t nhead_q;
     ck_tile::index_t nhead_k;
     ck_tile::index_t num_splits;
     float scale_s;
     float scale_p;
     float scale_o;
     ck_tile::index_t stride_q;
     ck_tile::index_t stride_k;
     ck_tile::index_t stride_v;
     ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
-    ck_tile::index_t stride_randval;
     ck_tile::index_t stride_o_acc;
     ck_tile::index_t stride_o;
     ck_tile::index_t nhead_stride_q;
     ck_tile::index_t nhead_stride_k;
     ck_tile::index_t nhead_stride_v;
     ck_tile::index_t nhead_stride_bias;
-    ck_tile::index_t nhead_stride_randval;
     ck_tile::index_t nhead_stride_lse;
     ck_tile::index_t nhead_stride_lse_acc;
     ck_tile::index_t nhead_stride_o_acc;
@@ -132,19 +208,62 @@ struct fmha_fwd_args
     ck_tile::index_t batch_stride_k;
     ck_tile::index_t batch_stride_v;
     ck_tile::index_t batch_stride_bias;
-    ck_tile::index_t batch_stride_randval;
     ck_tile::index_t batch_stride_lse;
     ck_tile::index_t batch_stride_lse_acc;
     ck_tile::index_t batch_stride_o_acc;
     ck_tile::index_t batch_stride_o;
     ck_tile::index_t split_stride_lse_acc;
     ck_tile::index_t split_stride_o_acc;
     ck_tile::index_t window_size_left;
     ck_tile::index_t window_size_right;
     ck_tile::index_t mask_type;
-    float p_drop;
-    bool s_randval;
-    std::tuple<uint64_t, uint64_t> drop_seed_offset;
 };
+
+struct fmha_fwd_appendkv_args
+{
+    void* q_ptr;
+    void* k_ptr;
+    const void* knew_ptr;
+    void* v_ptr;
+    const void* vnew_ptr;
+    const void* seqlen_k_ptr;
+    ck_tile::index_t seqlen_q;
+    ck_tile::index_t seqlen_knew;
+    ck_tile::index_t batch;
+    ck_tile::index_t hdim_q;
+    ck_tile::index_t hdim_v;
+    ck_tile::index_t nhead_q;
+    ck_tile::index_t nhead_k;
+    const void* rotary_cos_ptr; // only used if 'rotary_dim' > 0
+    const void* rotary_sin_ptr; // only used if 'rotary_dim' > 0
+    ck_tile::index_t rotary_dim;
+    bool has_mask;
+    void* block_table_ptr;
+    ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
+    ck_tile::index_t page_block_size;          // only used if 'block_table_ptr' is not nullptr
+    const void* cache_batch_idx;
+    ck_tile::index_t stride_q;
+    ck_tile::index_t stride_k;
+    ck_tile::index_t stride_knew;
+    ck_tile::index_t stride_v;
+    ck_tile::index_t stride_vnew;
+    ck_tile::index_t nhead_stride_q;
+    ck_tile::index_t nhead_stride_k;
+    ck_tile::index_t nhead_stride_knew;
+    ck_tile::index_t nhead_stride_v;
+    ck_tile::index_t nhead_stride_vnew;
+    ck_tile::index_t batch_stride_q;
+    ck_tile::index_t batch_stride_k;
+    ck_tile::index_t batch_stride_knew;
+    ck_tile::index_t batch_stride_v;
+    ck_tile::index_t batch_stride_vnew;
+};
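For orientation, a hypothetical host-side sketch of filling fmha_fwd_appendkv_args for a paged-KV decode step; every pointer/size name below (q_dev, block_table_dev, max_num_page_blocks, ...) is a placeholder, and the traits/stream setup is elided:

    fmha_fwd_appendkv_args a{};
    a.q_ptr        = q_dev;        // queries; rotated in place when rotary_dim > 0
    a.k_ptr        = k_cache_dev;  // KV cache, addressed through block_table_ptr when paged
    a.knew_ptr     = knew_dev;     // new keys, appended at offset seqlen_k_ptr[b]
    a.v_ptr        = v_cache_dev;
    a.vnew_ptr     = vnew_dev;
    a.seqlen_k_ptr = seqlen_k_dev; // per-batch current cache lengths
    a.seqlen_q     = 1;            // decode: a single new query token
    a.seqlen_knew  = 1;            // append one K/V token per sequence
    a.block_table_ptr          = block_table_dev;
    a.batch_stride_block_table = max_num_page_blocks; // placeholder: table row stride
    a.page_block_size          = 128; // 128 is the minimum used by this PR
    a.rotary_dim   = 0;            // rotary_cos_ptr/rotary_sin_ptr are only read when > 0
    // ... batch/nhead/hdim and the stride fields follow from the tensor layout ...
    float ave_time = fmha_fwd_appendkv(appendkv_traits, a, stream); // traits/stream prepared elsewhere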
 template <typename FmhaKernel>
@@ -244,7 +363,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
 }
 template <typename Kernel>
-auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 {
     assert(args.nhead_q % args.nhead_k == 0);
     auto kargs = [&] {
@@ -255,11 +374,9 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                 args.k_ptr,
                 args.v_ptr,
                 args.bias_ptr,
-                args.rand_val_ptr,
                 args.lse_acc_ptr,
                 args.o_acc_ptr,
                 args.batch,
-                args.max_seqlen_q,
                 args.seqstart_q_ptr,
                 args.seqstart_k_ptr,
                 args.seqlen_k_ptr,
@@ -274,24 +391,22 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                 args.stride_k,
                 args.stride_v,
                 args.stride_bias,
-                args.stride_randval,
                 args.stride_o_acc,
                 args.nhead_stride_q,
                 args.nhead_stride_k,
                 args.nhead_stride_v,
                 args.nhead_stride_bias,
-                args.nhead_stride_randval,
                 args.nhead_stride_lse_acc,
                 args.nhead_stride_o_acc,
+                args.batch_stride_k,
+                args.batch_stride_v,
+                args.batch_stride_lse_acc,
                 args.batch_stride_o_acc,
                 args.split_stride_lse_acc,
                 args.split_stride_o_acc,
                 args.window_size_left,
                 args.window_size_right,
-                args.mask_type,
-                args.p_drop,
-                args.s_randval,
-                args.drop_seed_offset);
+                args.mask_type);
         }
         else
         { // create batch mode kernel arguments
@@ -299,48 +414,45 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
                 args.k_ptr,
                 args.v_ptr,
                 args.bias_ptr,
-                args.rand_val_ptr,
                 args.lse_acc_ptr,
                 args.o_acc_ptr,
                 args.batch,
-                args.max_seqlen_q,
                 args.seqlen_q,
                 args.seqlen_k,
+                args.seqlen_k_ptr,
                 args.hdim_q,
                 args.hdim_v,
                 args.nhead_q,
                 args.nhead_q / args.nhead_k,
                 args.num_splits,
+                args.block_table_ptr,
+                args.batch_stride_block_table,
+                args.page_block_size,
+                args.cache_batch_idx,
                 args.scale_s,
                 args.scale_p,
                 args.stride_q,
                 args.stride_k,
                 args.stride_v,
                 args.stride_bias,
-                args.stride_randval,
                 args.stride_o_acc,
                 args.nhead_stride_q,
                 args.nhead_stride_k,
                 args.nhead_stride_v,
                 args.nhead_stride_bias,
-                args.nhead_stride_randval,
                 args.nhead_stride_lse_acc,
                 args.nhead_stride_o_acc,
                 args.batch_stride_q,
                 args.batch_stride_k,
                 args.batch_stride_v,
                 args.batch_stride_bias,
-                args.batch_stride_randval,
                 args.batch_stride_lse_acc,
                 args.batch_stride_o_acc,
                 args.split_stride_lse_acc,
                 args.split_stride_o_acc,
                 args.window_size_left,
                 args.window_size_right,
-                args.mask_type,
-                args.p_drop,
-                args.s_randval,
-                args.drop_seed_offset);
+                args.mask_type);
         }
     }();
@@ -351,7 +463,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 }
 template <typename Kernel>
-auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_args args)
+auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
 {
     assert(args.nhead_q % args.nhead_k == 0);
     auto kargs = [&] {
@@ -410,6 +522,51 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
     return ck_tile::make_tuple(kargs, grids);
 }
template <typename Kernel>
auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.knew_ptr,
args.v_ptr,
args.vnew_ptr,
args.seqlen_q,
args.seqlen_k_ptr,
args.seqlen_knew,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.has_mask,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.stride_q,
args.stride_k,
args.stride_knew,
args.stride_v,
args.stride_vnew,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_knew,
args.nhead_stride_v,
args.nhead_stride_vnew,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_knew,
args.batch_stride_v,
args.batch_stride_vnew);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.seqlen_knew);
return ck_tile::make_tuple(kargs, grids);
}
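For context, the generated dispatch code is expected to consume this helper roughly as below. This is a hedged sketch, assuming Kernel exposes BlockSize() and kBlockPerCu the same way the existing fmha_fwd kernels in this file do; it is not the verbatim generated code.

template <typename Kernel>
float fmha_fwd_appendkv_run_sketch(const ck_tile::stream_config& s, fmha_fwd_appendkv_args args)
{
    // build kernel arguments and grid size, then launch through the common helper
    auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<Kernel>(args);
    constexpr dim3 blocks                  = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = Kernel::kBlockPerCu;
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}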
// this is used to pattern-match internal kernel implementation, not to instantiate kernel // this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_, template <ck_tile::index_t HDim_,
typename DataType_, typename DataType_,
...@@ -458,8 +615,52 @@ struct fmha_fwd_traits_ ...@@ -458,8 +615,52 @@ struct fmha_fwd_traits_
template <typename Traits_> template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kIsPagedKV_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_splitkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_> template <typename Traits_>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_fwd_splitkv_get_name_(); std::string fmha_fwd_splitkv_get_name_();
...@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_ ...@@ -487,11 +688,45 @@ struct fmha_fwd_splitkv_combine_traits_
}; };
template <typename Traits_> template <typename Traits_>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_args); void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_> template <typename Traits_>
std::string fmha_fwd_splitkv_combine_get_name_(); std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match internal kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
ck_tile::index_t kTileSizeS_,
ck_tile::index_t kTileSizeSk_,
ck_tile::index_t kTileSizeD_,
ck_tile::index_t kTileSizeDv_,
bool kIsVLayoutRowMajor_,
bool kPadS_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
ck_tile::RotaryEmbeddingEnum RotaryEnum_,
bool kIsPagedKV_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct fmha_fwd_traits struct fmha_fwd_traits
{ {
...@@ -508,4 +743,32 @@ struct fmha_fwd_traits ...@@ -508,4 +743,32 @@ struct fmha_fwd_traits
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
float fmha_fwd_splitkv(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
struct fmha_fwd_splitkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,
fmha_fwd_splitkv_args,
const ck_tile::stream_config&);
struct fmha_fwd_appendkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_v_rowmajor;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,
const ck_tile::stream_config&);
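A hedged caller-side sketch of how the two new entry points compose for kvcache decoding: first append the new K/V tokens (optionally applying RoPE), then run the split-kv forward pass over the updated cache. All buffers and field values below are illustrative placeholders, not values from this PR.

// illustrative only: fill the real pointer/stride fields from your tensors
fmha_fwd_appendkv_traits append_traits{/*hdim_q=*/128, /*hdim_v=*/128,
                                       /*data_type=*/"fp16",
                                       /*is_v_rowmajor=*/true,
                                       /*rope_type=*/rope_enum::none};
fmha_fwd_appendkv_args append_args{}; // q/k/knew/v/vnew pointers, seqlens, strides...
float t_append = fmha_fwd_appendkv(append_traits, append_args, ck_tile::stream_config{});

fmha_fwd_splitkv_traits fwd_traits{}; // must agree with append_traits on hdim/data_type
fmha_fwd_splitkv_args fwd_args{};     // points at the K/V cache appended above
float t_fwd = fmha_fwd_splitkv(fwd_traits, fwd_args, ck_tile::stream_config{});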
...@@ -5,25 +5,30 @@ ...@@ -5,25 +5,30 @@
import argparse import argparse
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
import pkgutil
import sys
from typing import List, Optional from typing import List, Optional
import codegen.ops
from codegen.cmake_config import * from codegen.cmake_config import *
from codegen.ops import (
fmha_fwd,
fmha_fwd_splitkv,
fmha_bwd
)
class HandlerId(IntEnum): class HandlerId(IntEnum):
LIST_BLOBS = 0 LIST_BLOBS = 0
WRITE_BLOBS = 1 WRITE_BLOBS = 1
handlers = { # inspect all modules under 'codegen.ops' and register API handlers
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), ops = []
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
} if full_module_name not in sys.modules:
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None: if output_dir is None:
...@@ -103,4 +108,4 @@ if __name__ == "__main__": ...@@ -103,4 +108,4 @@ if __name__ == "__main__":
if args.list_blobs is not None: if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask) list_blobs(args.list_blobs, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
else: else:
write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask) write_blobs(args.output_dir, api_list, args.filter, int(args.receipt), mask_impl=args.mask)
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <optional>
#include <random>
#include <tuple>
// keep in sync with RotaryEmbeddingEnum
enum class rope_enum
{
none = 0,
interleaved = 1,
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen,
ck_tile::index_t rotary_dim,
std::optional<unsigned> seed = std::nullopt)
{
// return dummy tensors if we won't apply RoPE at all
if(rotary_dim <= 0)
{
ck_tile::HostTensor<DataType> dummy({1, 1});
return std::make_tuple(dummy, dummy);
}
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
const ck_tile::index_t num_rows = seqlen * 2;
const ck_tile::index_t num_cols = rotary_dim / 2;
using std::begin, std::end;
ck_tile::HostTensor<float> angle({num_rows, num_cols});
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::cos(origin_value));
});
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
return ck_tile::type_convert<DataType>(std::sin(origin_value));
});
return std::make_tuple(cos, sin);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
const ck_tile::HostTensor<DataType>& sin,
ck_tile::index_t seqlen_offset,
ck_tile::index_t seqlen)
{
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
const ck_tile::index_t num_rows = seqlen;
const ck_tile::index_t num_cols = cos.get_length(1);
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
return std::make_tuple(cos_pt, sin_pt);
}
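Taken together, a host test first generates one full table and then slices the rows matching the current cache offset. A minimal hedged sketch:

// hedged usage sketch: table rows [past, past + n_new) feed the appended tokens
const ck_tile::index_t seqlen = 64, rotary_dim = 32, past = 16, n_new = 8;
auto [cos_full, sin_full] = generate_rotary_cos_sin<ck_tile::half_t>(seqlen, rotary_dim);
auto [cos_new, sin_new]   = slice_rotary_cos_sin(cos_full, sin_full, past, n_new);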
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
VALID=0 VALID=0
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
......
#!/bin/sh #!/bin/sh
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_bwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
......
#!/bin/sh #!/bin/bash
# TODO: run this script from CK root # TODO: run this script from CK root or build directory
BUILD=build EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
EXE=$BUILD/bin/tile_example_fmha_fwd
KNAME=1 KNAME=1
export CK_WARMUP=0 export CK_WARMUP=0
...@@ -10,44 +9,98 @@ export CK_REPEAT=1 ...@@ -10,44 +9,98 @@ export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1' COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0 # mode=0
# export HIP_VISIBLE_DEVICES=4 # export HIP_VISIBLE_DEVICES=4
set -x
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done TEST_SPLITKV=0
done TEST_APPENDKV=0
done # options:
done # -s: run splitkv tests
done # -a: run appendkv tests
done while getopts ":sa" opt; do
done case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done done
run_fp16_bf16_tests() {
local NUM_SPLITS=(1)
local PAGE_BLOCK_SIZE=(0)
local CACHE_BATCH_IDX=(0)
for perm in 0 1 ; do if [ $TEST_SPLITKV -eq 1 ] ; then
for bias in "n" "e" "a" ; do NUM_SPLITS+=(2 3)
for b in 1 2 ; do PAGE_BLOCK_SIZE+=(128)
for hdim in 64 128 256 ; do CACHE_BATCH_IDX+=(1)
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS fi
done
done for prec in "fp16" "bf16" ; do
done for mode in 1 0 ; do
done for perm in 0 1 ; do
set +x for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
\ No newline at end of file
...@@ -3,15 +3,17 @@ ...@@ -3,15 +3,17 @@
#pragma once #pragma once
#include <algorithm>
#include <cstdint> #include <cstdint>
#include <cstdlib> #include <cstdlib>
#include <functional>
#include <limits>
#include <numeric>
#include <optional> #include <optional>
#include <ostream> #include <ostream>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional>
#include <string>
#include "ck_tile/core/container/span.hpp" #include "ck_tile/core/container/span.hpp"
...@@ -37,18 +39,21 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens) ...@@ -37,18 +39,21 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
return seqstarts; return seqstarts;
} }
std::vector<int32_t> generate_seqlens(mode_enum mode, std::vector<int32_t> generate_seqlens(unsigned count,
unsigned count,
int32_t seqlen_avg, int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt) std::optional<unsigned> seed = std::nullopt)
{ {
assert(0 < count); assert(0 < count);
std::vector<int32_t> seqlens( seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
count, seqlen_max > 0 ? (seqlen_avg < seqlen_max ? seqlen_avg : seqlen_max) : seqlen_avg); seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
if(mode == mode_enum::group && 1 < count) std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(1 < count)
{ {
using size_type = std::vector<int32_t>::size_type; using size_type = std::vector<int32_t>::size_type;
...@@ -62,15 +67,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -62,15 +67,15 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat) for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{ {
const size_type to_decrease = next_idx(); const size_type to_decrease = next_idx();
// make sure each element of seqlens is always greater than 0 // make sure each element of seqlens never drops below seqlen_min
if(seqlens[to_decrease] == 1) if(seqlens[to_decrease] == seqlen_min)
{ {
continue; continue;
} }
const size_type to_increase = (to_decrease + next_step()) % count; const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlen_max > 0 && seqlens[to_increase] >= seqlen_max) if(seqlens[to_increase] >= seqlen_max)
{ {
continue; continue;
} }
...@@ -83,13 +88,29 @@ std::vector<int32_t> generate_seqlens(mode_enum mode, ...@@ -83,13 +88,29 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
return seqlens; return seqlens;
} }
std::vector<int32_t> generate_seqstarts(mode_enum mode, // return random integer generated uniformly in range [low, high]
unsigned count, template <typename Int = int>
int32_t seqlen_avg, auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
int32_t seqlen_max = -1, -> std::enable_if_t<std::is_integral_v<Int>, Int>
std::optional<unsigned> seed = std::nullopt) {
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
{ {
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_max, seed)); std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(engine); });
} }
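For instance, the batch-mode kvcache path below uses randints() to draw per-batch seqlen_k values; in isolation it looks like this (hedged sketch):

// hedged sketch: fill 8 lengths drawn uniformly from [1, 256]
std::vector<int32_t> lens(8);
randints(lens.begin(), lens.end(), /*low=*/1, /*high=*/256);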
/* /*
...@@ -112,16 +133,45 @@ decode_seqlen(mode_enum mode, ...@@ -112,16 +133,45 @@ decode_seqlen(mode_enum mode,
std::string q_val, std::string q_val,
std::string k_val, std::string k_val,
std::string k_pad_val, std::string k_pad_val,
std::optional<unsigned> seed = std::nullopt) ck_tile::index_t seqlen_k_min = 0,
bool use_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
{ {
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str())) #define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if(mode == mode_enum::batch) if(mode == mode_enum::batch)
{ {
ck_tile::index_t q = _S2I_(q_val); ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val); ck_tile::index_t k = _S2I_(k_val);
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = std::vector<ck_tile::index_t>(batch, k < 0 ? q : k); auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && use_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in the first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad); return std::make_tuple(s_q, s_k, s_kpad);
} }
else else
...@@ -149,6 +199,16 @@ decode_seqlen(mode_enum mode, ...@@ -149,6 +199,16 @@ decode_seqlen(mode_enum mode,
s_q.push_back(q); s_q.push_back(q);
s_k.push_back(k < 0 ? q : k); s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp); s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++; idx++;
if(found_q == std::string::npos || idx >= batch) if(found_q == std::string::npos || idx >= batch)
{ {
...@@ -160,8 +220,9 @@ decode_seqlen(mode_enum mode, ...@@ -160,8 +220,9 @@ decode_seqlen(mode_enum mode,
} }
if(idx < batch) if(idx < batch)
{ {
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), s_kpad.back(), seed); auto rem_q = generate_seqlens(batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k = generate_seqlens(mode, batch - idx, s_k.back(), s_kpad.back(), seed); auto rem_k =
generate_seqlens(batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end()); s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end()); s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
...@@ -180,3 +241,15 @@ int env_get_int(const char* var_name, int default_int) ...@@ -180,3 +241,15 @@ int env_get_int(const char* var_name, int default_int)
r = std::atoi(v); r = std::atoi(v);
return r; return r;
} }
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
{
std::iota(first, last, value);
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
}
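A plausible use of this helper is building a shuffled cache_batch_idx mapping for the -cache_batch_idx=1 tests (an assumption, not confirmed by this diff):

// hedged sketch: a random permutation of indices 0..batch-1
const int batch = 4;
std::vector<int32_t> cache_batch_idx(batch);
iota_shuffle(cache_batch_idx.begin(), cache_batch_idx.end(), 0);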
...@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); }; ...@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST CK_TILE_HOST
float log(float x) { return std::logf(x); }; float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
{ {
// TODO: this is hacky, we use u16
return __builtin_amdgcn_sad_u16(x, y, acc); return __builtin_amdgcn_sad_u16(x, y, acc);
} }
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) CK_TILE_DEVICE uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
{
/// TODO: replace inline asm when intrinsic is available
uint32_t res;
asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
return res;
}
CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
{ {
return (x > y ? (x - y) : (y - x)) + acc; return (x > y ? (x - y) : (y - x)) + acc;
} }
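The host fallback pins down the semantics (sum of absolute difference plus an accumulator); assuming, as for the neighboring math helpers, that these live in namespace ck_tile:

// host-side semantics check: |x - y| + acc
assert(ck_tile::sad_u32(7u, 3u, 1u) == 5u);
assert(ck_tile::sad_u32(3u, 7u, 0u) == 4u);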
......
...@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution ...@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move thread's window adaptor coordinate and bottom tensor coordinate // move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
...@@ -843,6 +849,17 @@ struct tile_window_with_static_lengths ...@@ -843,6 +849,17 @@ struct tile_window_with_static_lengths
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
{
window_origin_ = new_window_origin;
}
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// move window-origin // move window-origin
CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }
...@@ -871,6 +888,39 @@ make_tile_window(const TensorView_& tensor_view, ...@@ -871,6 +888,39 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view, window_lengths, origin}; tensor_view, window_lengths, origin};
} }
// duplicate tile window and replace its origin
template <typename TensorView, typename WindowLengths>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin)
{
return tile_window_with_static_lengths<TensorView, WindowLengths>{
tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const multi_index<TensorView::get_num_of_dimension()>& origin,
const StaticTileDistribution& tile_distribution)
{
return make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
origin,
tile_distribution);
}
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
CK_TILE_DEVICE constexpr auto
make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
const StaticTileDistribution& tile_distribution)
{
return make_tile_window(tile_window.get_bottom_tensor_view(),
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_distribution);
}
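A hedged sketch of what these overloads enable inside a kernel: clone a window at a new origin, or attach a distribution, without restating the view and lengths. k_page_window, i_page, kN0 and dist are assumed names, not symbols from this diff.

// assumed context: k_page_window is a tile_window_with_static_lengths
auto next_window = make_tile_window(k_page_window, {i_page * kN0, 0});
// same window, now carrying a static tile distribution for distributed loads
auto load_window = make_tile_window(next_window, dist);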
template <typename TensorView_, typename WindowLengths_> template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window( CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window, tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
......
...@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>; ...@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template <typename T> template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type; using remove_pointer_t = typename std::remove_pointer<T>::type;
template <typename From, typename To>
struct copy_const
{
static_assert(!std::is_const_v<From>);
using type = To;
};
template <typename From, typename To>
struct copy_const<const From, To>
{
using type = std::add_const_t<typename copy_const<From, To>::type>;
};
template <typename From, typename To>
using copy_const_t = typename copy_const<From, To>::type;
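The trait forwards const from the source type onto the destination type; two checks pin down the behavior (assuming it lives in namespace ck_tile alongside remove_cvref_t):

static_assert(std::is_same_v<ck_tile::copy_const_t<float, int>, int>);
static_assert(std::is_same_v<ck_tile::copy_const_t<const float, int>, const int>);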
namespace detail { namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args> template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector struct detector
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
......
...@@ -155,7 +155,12 @@ struct HostTensorDescriptor ...@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return space; return space;
} }
std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
const std::vector<std::size_t>& get_lengths() const { return mLens; } const std::vector<std::size_t>& get_lengths() const { return mLens; }
std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
const std::vector<std::size_t>& get_strides() const { return mStrides; } const std::vector<std::size_t>& get_strides() const { return mStrides; }
template <typename... Is> template <typename... Is>
...@@ -325,8 +330,12 @@ struct HostTensor ...@@ -325,8 +330,12 @@ struct HostTensor
{ {
} }
std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
decltype(auto) get_lengths() const { return mDesc.get_lengths(); } decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
decltype(auto) get_strides() const { return mDesc.get_strides(); } decltype(auto) get_strides() const { return mDesc.get_strides(); }
std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); } std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
......
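The new per-dimension getters are what the RoPE utilities above call; a quick hedged illustration, assuming the default descriptor is packed row-major:

ck_tile::HostTensor<float> t({4, 8});
assert(t.get_length(0) == 4 && t.get_length(1) == 8);
assert(t.get_stride(0) == 8 && t.get_stride(1) == 1); // packed row-major (assumption)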
...@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) ...@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{ {
// clang-format off // clang-format off
if(!s.time_kernel_) { if(!s.time_kernel_) {
(callables(s),...); hip_check_error(hipGetLastError()); (callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
return 0; return 0;
} }
if(s.is_gpu_timer_) { if(s.is_gpu_timer_) {
gpu_timer timer {}; gpu_timer timer {};
// warmup // warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.start(s.stream_id_); timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.stop(s.stream_id_); timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_; return timer.duration() / s.nrepeat_;
...@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables) ...@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
cpu_timer timer {}; cpu_timer timer {};
// warmup // warmup
for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.start(s.stream_id_); timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError()); for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
timer.stop(s.stream_id_); timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_; return timer.duration() / s.nrepeat_;
......
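The switch from hip_check_error() to the HIP_CHECK_ERROR() macro matters because a macro expands at the call site, so __FILE__/__LINE__ report where the failing launch actually happened. A hypothetical shape of such a macro (not the real definition from this repo):

// hypothetical sketch only
#define SKETCH_HIP_CHECK(call)                                    \
    do                                                            \
    {                                                             \
        hipError_t err_ = (call);                                 \
        if(err_ != hipSuccess)                                    \
        {                                                         \
            fprintf(stderr, "HIP error: %s at %s:%d\n",           \
                    hipGetErrorString(err_), __FILE__, __LINE__); \
            std::abort();                                         \
        }                                                         \
    } while(0)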
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace ck_tile {
template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
const HostTensor<DataType>& cos_sd,
const HostTensor<DataType>& sin_sd,
bool interleaved,
HostTensor<DataType>& output_bsd,
bool use_1_row_sin_cos = false)
{
assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
cos_sd.get_length(1) == sin_sd.get_length(1));
const index_t rotary_dim = cos_sd.get_length(1) * 2;
assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
output_bsd.ForEach([&](auto& self, auto i) {
const index_t i_d = i[2];
if(rotary_dim <= i_d)
{
self(i) = input_bsd(i);
return;
}
assert(i_d < rotary_dim);
const index_t i_s = i[1];
const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
const ComputeDataType cos = type_convert<ComputeDataType>(
interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
: cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
: sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];
if(interleaved)
{
const bool is_even = (i_d % 2 == 0);
const index_t pos = i_d + (is_even ? 1 : -1);
const ComputeDataType sign = (is_even ? -1 : 1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
else
{
const index_t half_rdim = (rotary_dim / 2);
const index_t pos = (i_d + half_rdim) % rotary_dim;
const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
}
}();
ComputeDataType result =
type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
self(i) = type_convert<DataType>(result);
});
}
} // namespace ck_tile
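End-to-end, the host utilities in this PR compose as below: the reference computes out = in * cos + rotate_half(in) * sin over the first rotary_dim features and passes the rest through unchanged. A hedged usage sketch, assuming half_t input:

// hedged usage sketch of the host-side RoPE reference
const ck_tile::index_t batch = 2, seqlen = 64, hdim = 128, rotary_dim = 32;
ck_tile::HostTensor<ck_tile::half_t> q({batch, seqlen, hdim});     // fill with inputs
ck_tile::HostTensor<ck_tile::half_t> q_ref({batch, seqlen, hdim});
auto [cos_full, sin_full] = generate_rotary_cos_sin<ck_tile::half_t>(seqlen, rotary_dim);
auto [cos_sl, sin_sl]     = slice_rotary_cos_sin(cos_full, sin_full, /*seqlen_offset=*/0, seqlen);
ck_tile::reference_batched_rotary_position_embedding(
    q, cos_sl, sin_sl, /*interleaved=*/true, q_ref);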