"vscode:/vscode.git/clone" did not exist on "f75b8aa9dd4a982dacaa94b9fd3bddcdcb821632"
Unverified Commit 989ffa2b authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #263 from ROCm/develop

Merge from public repo
parents 0dd5d62a cabbacb6
We'd love for you to contribute to our source code!
Some helpful links:
- [Code of Conduct guidelines](https://www.contributor-covenant.org/version/2/1/code_of_conduct/code_of_conduct.txt)
- [New issue guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/ISSUE_TEMPLATE.md)
- [Submitting a pull request guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/PULL_REQUEST_TEMPLATE.md)
- [Maintainers](https://github.com/rocm/composable_kernel/blob/develop/CONTRIBUTORS.md)
- [General information](https://github.com/rocm/composable_kernel/blob/develop/README.md)
- [ROCm documentation](https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/optimizing-with-composable-kernel.html)
\ No newline at end of file
When creating an issue, please check if a similar issue already exists.
### When reporting a bug, please include:
- [ ] A descriptive title
- [ ] An isolated way to reproduce the behavior (preferably a docker container with a repro)
- [ ] ROCm version, clang version, Composable Kernel commit pin
- [ ] Environment variables
- [ ] The behavior you expect to see, and the behavior you actually see
### When requesting a feature, please include:
- [ ] A descriptive title
- [ ] A detailed description of the problem you are trying to solve
- [ ] An overview of the suggested solution
- [ ] Explanation why the solution is an improvement
\ No newline at end of file
## Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
## Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added inline documentation which enables the maintainers with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged
## Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
...@@ -183,10 +183,12 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") ...@@ -183,10 +183,12 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances") message("Enabling XDL instances")
add_definitions(-DCK_USE_XDL) add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
message("Enabling FP8 gemms on native architectures") message("Enabling FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94) add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx95") if (SUPPORTED_GPU_TARGETS MATCHES "gfx95")
add_definitions(-DCK_USE_AMD_MFMA_GFX950) add_definitions(-DCK_USE_AMD_MFMA_GFX950)
...@@ -194,6 +196,7 @@ endif() ...@@ -194,6 +196,7 @@ endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances") message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA) add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif() endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950")
add_definitions(-DCK_USE_OCP_FP8) add_definitions(-DCK_USE_OCP_FP8)
...@@ -207,6 +210,7 @@ endif() ...@@ -207,6 +210,7 @@ endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH) add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON")
endif() endif()
# CK config file to record supported datatypes, etc. # CK config file to record supported datatypes, etc.
......
FROM ubuntu:20.04 FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=6.3 ARG ROCMVERSION=6.3
ARG compiler_version="" ARG compiler_version=""
...@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev \ libnuma-dev \
libpthread-stubs0-dev \ libpthread-stubs0-dev \
llvm-amdgpu \ llvm-amdgpu \
mpich \
net-tools \ net-tools \
pkg-config \ pkg-config \
python \ python \
...@@ -63,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -63,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \ nano \
zlib1g-dev \ zlib1g-dev \
zip \ zip \
libzstd-dev \
openssh-server \ openssh-server \
clang-format-12 \ clang-format-12 \
kmod && \ kmod && \
...@@ -70,7 +72,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- ...@@ -70,7 +72,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rm -rf /var/lib/apt/lists/* && \ rm -rf /var/lib/apt/lists/* && \
rm -rf amdgpu-install* && \ rm -rf amdgpu-install* && \
# Remove unnecessary rocm components that take a lot of space # Remove unnecessary rocm components that take a lot of space
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt
# Update the cmake to version 3.27.5 # Update the cmake to version 3.27.5
RUN pip install --upgrade cmake==3.27.5 && \ RUN pip install --upgrade cmake==3.27.5 && \
...@@ -92,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ ...@@ -92,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results # Install packages for processing the performance results
pip3 install --upgrade pip && \ pip3 install --upgrade pip && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.0.3 setuptools-rust sshtunnel==0.4.0 && \ pip3 install sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust sshtunnel==0.4.0 && \
# Add render group # Add render group
groupadd -f render && \ groupadd -f render && \
# Install the new rocm-cmake version # Install the new rocm-cmake version
......
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub20.04_rocm6.3" ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3"
FROM $BASE_DOCKER FROM $BASE_DOCKER
ARG compiler_version="" ARG compiler_version=""
ARG compiler_commit="" ARG compiler_commit=""
......
...@@ -40,10 +40,10 @@ def getBaseDockerImageName(){ ...@@ -40,10 +40,10 @@ def getBaseDockerImageName(){
else{ else{
def ROCM_numeric = "${params.ROCMVERSION}" as float def ROCM_numeric = "${params.ROCMVERSION}" as float
if ( ROCM_numeric < 6.4 ){ if ( ROCM_numeric < 6.4 ){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}"
} }
else{ else{
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}" img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}"
} }
} }
return img return img
...@@ -357,7 +357,7 @@ def buildHipClangJob(Map conf=[:]){ ...@@ -357,7 +357,7 @@ def buildHipClangJob(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
...@@ -426,7 +426,7 @@ def Build_CK(Map conf=[:]){ ...@@ -426,7 +426,7 @@ def Build_CK(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm") def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group // Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) { if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 " dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
} }
...@@ -566,11 +566,9 @@ def Build_CK(Map conf=[:]){ ...@@ -566,11 +566,9 @@ def Build_CK(Map conf=[:]){
ls -ltr ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install" CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install"
cmake --build build -- -j cmake --build build -- -j
ctest --test-dir build
""" """
} }
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
} }
} }
} }
...@@ -752,8 +750,8 @@ pipeline { ...@@ -752,8 +750,8 @@ pipeline {
description: "Test building instances for various architectures simultaneously (default: OFF)") description: "Test building instances for various architectures simultaneously (default: OFF)")
booleanParam( booleanParam(
name: "BUILD_GFX12", name: "BUILD_GFX12",
defaultValue: false, defaultValue: true,
description: "Build CK and run tests on gfx12 (default: OFF)") description: "Build CK and run tests on gfx12 (default: ON)")
booleanParam( booleanParam(
name: "NINJA_BUILD_TRACE", name: "NINJA_BUILD_TRACE",
defaultValue: false, defaultValue: false,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h> #include <hip/hip_runtime_api.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <stdexcept>
namespace rtc { namespace rtc {
......
rocm-docs-core==1.11.0 rocm-docs-core==1.12.0
sphinxcontrib-bibtex==2.6.3 sphinxcontrib-bibtex==2.6.3
...@@ -103,7 +103,7 @@ requests==2.32.3 ...@@ -103,7 +103,7 @@ requests==2.32.3
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.11.0 rocm-docs-core==1.12.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via pybtex # via pybtex
......
...@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD ...@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1 8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM 0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1 8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN 0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
...@@ -2,10 +2,17 @@ ...@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation # generate kernel instances to speed up compilation
DTYPE_MAP = { FWD_DTYPE_MAP = {
"fp16": "ck_tile::fp16_t", "fp16" : "FmhaFwdFp16",
"bf16": "ck_tile::bf16_t", "bf16" : "FmhaFwdBf16",
"fp8" : "ck_tile::fp8_t" "fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
} }
MASK_IMPL = { MASK_IMPL = {
......
...@@ -283,7 +283,7 @@ class FmhaBwdApiPool: ...@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic]) F_deterministic=BOOL_MAP[trait.deterministic])
...@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel: ...@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format( FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaBwdApiPool(mask_impl) api_pool = FmhaBwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel: ...@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format( FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad], F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode], F_mode = MODE_MAP[self.F_mode],
...@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]: ...@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel: ...@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format( FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0, F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0, F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad], F_spad = BOOL_MAP[self.F_spad],
...@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]: ...@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -282,7 +282,7 @@ class FmhaFwdApiPool: ...@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -301,7 +301,7 @@ class FmhaFwdTileSize: ...@@ -301,7 +301,7 @@ class FmhaFwdTileSize:
F_bk1 : int # tile size along kv gemm unroll F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used) F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v F_rn1 : int # number of warps for gemm1 along head dim v
...@@ -339,7 +339,7 @@ class FmhaFwdKernel: ...@@ -339,7 +339,7 @@ class FmhaFwdKernel:
FMHA_FWD_KERNEL_BODY.format( FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# no need lse/dropout kernels # no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm ...@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
gen = list() gen = list()
api_pool = FmhaFwdApiPool(mask_impl) api_pool = FmhaFwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype) d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool: ...@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel: ...@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY.format( FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs, F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk, F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd, F_bd = self.F_tile.F_bd,
...@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported # rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f')) pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl) api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype) d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) ...@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}} }}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>; {F_dvpad}>;
#include <iostream> #include <iostream>
...@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem< ...@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType, typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim}, {F_hdim},
{F_bm0}, {F_bm0},
{F_bn1}, {F_bn1},
{F_mode}, {F_mode},
fmha_trait>; fmha_trait>;
...@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a ...@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
if(s.log_level_ > 0) if(s.log_level_ > 0)
std::cout std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>() << ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>() << ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush; << std::flush;
return ck_tile::launch_kernel(s, return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }}, [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }} [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
); );
}} }}
...@@ -431,11 +431,11 @@ class FmhaFwdSplitKVApiPool: ...@@ -431,11 +431,11 @@ class FmhaFwdSplitKVApiPool:
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if' if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if' if_i = 'if' if i == 0 else 'else if'
...@@ -472,7 +472,7 @@ class FmhaFwdSplitKVKernel: ...@@ -472,7 +472,7 @@ class FmhaFwdSplitKVKernel:
FMHA_FWD_SPLITKV_KERNEL_BODY.format( FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0, F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0, F_bk0 = self.F_tile.F_bk0,
...@@ -492,7 +492,7 @@ class FmhaFwdSplitKVKernel: ...@@ -492,7 +492,7 @@ class FmhaFwdSplitKVKernel:
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_squant = BOOL_MAP[self.F_pipeline.F_squant],
...@@ -552,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel: ...@@ -552,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format( FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx = self.F_idx, F_idx = self.F_idx,
F_hdim = self.F_hdim, F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype], F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0, F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1, F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad], F_spad = BOOL_MAP[self.F_pipeline.F_spad],
...@@ -625,7 +625,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -625,7 +625,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines = [] pipelines = []
if dtype in ['fp16', 'bf16']: if dtype in ['fp16', 'bf16']:
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable # TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True: # if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
...@@ -644,6 +644,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -644,6 +644,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']: elif dtype in ['fp8', 'bf8']:
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else: else:
assert False assert False
return pipelines return pipelines
...@@ -651,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> ...@@ -651,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list() gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl) api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype) d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
...@@ -711,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis ...@@ -711,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
gen = list() gen = list()
for dtype in DTYPE_MAP.keys(): for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None: if d == None:
continue continue
......
...@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[]) ...@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
} }
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
...@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) ...@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_ ...@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return ck_tile::make_tuple(rtol, atol); return ck_tile::make_tuple(rtol, atol);
} }
template <typename DataType> template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
...@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
using TypeConfig = FmhaBwdTypeConfig<DataType>; using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType; using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType; using KDataType = typename TypeConfig::KDataType;
...@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v); auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result, bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref, dq_host_ref,
std::string("Error: QGrad Incorrect results!"), std::string("Error: QGrad Incorrect results!"),
...@@ -986,11 +986,11 @@ int main(int argc, char* argv[]) ...@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec"); const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run<ck_tile::half_t>(arg_parser) ? 0 : -2; return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2; return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
...@@ -14,11 +14,19 @@ ...@@ -14,11 +14,19 @@
#include <utility> #include <utility>
#include <variant> #include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType> template <typename DataType>
struct FmhaBwdTypeConfig; struct FmhaBwdTypeConfig;
template <> template <>
struct FmhaBwdTypeConfig<ck_tile::half_t> struct FmhaBwdTypeConfig<FmhaBwdFp16>
{ {
using QDataType = ck_tile::half_t; using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t; using KDataType = ck_tile::half_t;
...@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t> ...@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
}; };
template <> template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t> struct FmhaBwdTypeConfig<FmhaBwdBf16>
{ {
using QDataType = ck_tile::bf16_t; using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp" #include "fmha_fwd.hpp"
#include "ck_tile/host.hpp" #include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp" #include "mask.hpp"
#include "rotary.hpp" #include "rotary.hpp"
#include "utils.hpp" #include "utils.hpp"
...@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v) ...@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not") arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
.insert("mode", "0", "kernel mode. 0:batch, 1:group") .insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size") .insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q") .insert("h", "8", "num of head, for q")
...@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[]) ...@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
} }
// different threshold for different dtype // different threshold for different dtype
template <typename DataType> template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/) auto get_elimit(std::string /*init_method*/)
{ {
double rtol = 1e-3; double rtol = 1e-3;
...@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/) ...@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/) auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{ {
double rtol = 1e-2; double rtol = 1e-2;
double atol = 1e-2; double atol = 1e-2;
...@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/) ...@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
} }
template <> template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method) auto get_elimit<FmhaFwdFp8>(std::string init_method)
{ {
if(init_method == "ui" || init_method == "ni") if(init_method == "ui" || init_method == "ni")
{ {
...@@ -261,7 +262,7 @@ int override_num_splits_if_necessary( ...@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
return num_splits; return num_splits;
} }
template <typename DataType> template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
std::string data_type = arg_parser.get_str("prec"); std::string data_type = arg_parser.get_str("prec");
...@@ -305,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -305,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> || if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
std::is_same_v<DataType, ck_tile::bf16_t>)) std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
{ {
if(0 < rotary_dim) if(0 < rotary_dim)
{ {
...@@ -428,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -428,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
return atoi(squant_str.c_str()) != 0 ? true : false; return atoi(squant_str.c_str()) != 0 ? true : false;
}(); }();
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
scale_p = dtype_max / range_p;
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
scale_o = range_p * range_v / range_o / dtype_max;
}
std::string vlayout = arg_parser.get_str("vlayout"); std::string vlayout = arg_parser.get_str("vlayout");
bool lse = arg_parser.get_bool("lse"); bool lse = arg_parser.get_bool("lse");
...@@ -466,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -466,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
} }
bool s_randval = false; bool s_randval = false;
if(p_drop > 0.0f && do_validation) if(p_drop > 0.0f && do_validation != 0)
{ {
s_randval = true; s_randval = true;
} }
...@@ -499,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -499,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
using TypeConfig = FmhaFwdTypeConfig<DataType>; using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType; using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType; using KDataType = typename TypeConfig::KDataType;
...@@ -513,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -513,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
using OaccDataType = typename TypeConfig::OaccDataType; using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType; using ODataType = typename TypeConfig::ODataType;
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v is the same type
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
scale_p = p_dtype_max / range_p;
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
}
// accumulation numbers for performance evaluation // accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q = auto max_seqlen_q =
...@@ -709,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -709,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(init_method == "ufq" || init_method == "uf:q" || else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization init_method == "3") // suitable for fp8 quantization
{ {
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host); ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host); ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host); ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host); ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host); ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32 // bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
// Assume bias is in [-1.f, 1.f] in original fp32 // Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host); ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
} }
...@@ -1118,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1118,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush; << " GB/s" << std::flush;
if(!do_validation) if(do_validation == 0)
{ {
std::cout << std::flush << std::endl; std::cout << std::flush << std::endl;
return true; return true;
} }
if(do_validation == 2)
{
// NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
ck_tile::naive_attention_fwd_args naive_a;
naive_a.q_ptr = q_buf.GetDeviceBuffer();
naive_a.k_ptr = k_buf.GetDeviceBuffer();
naive_a.v_ptr = v_buf.GetDeviceBuffer();
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
naive_a.scale_s = scale_s;
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
naive_a.page_table_ptr =
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
naive_a.hdim = hdim_q;
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
naive_a.batch_q = batch;
naive_a.batch_kv = batch;
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
naive_a.seqlen_q = seqlen_qs[0];
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
naive_a.nhead_q = nhead;
naive_a.nhead_kv = nhead_k;
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
ck_tile::stream_config naive_s{};
naive_attention_fwd(naive_t, naive_a, naive_s);
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
return pass_;
}
o_buf.FromDevice(o_host.data()); o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data()); lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data()); randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() { auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::scales{scale_p}; return ck_tile::scales{scale_p};
else else
return ck_tile::identity{}; return ck_tile::identity{};
}(); }();
auto oacc_element_func = [&]() { auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>) if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{}, return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o}); ck_tile::scales{scale_o});
else else
...@@ -1186,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1186,7 +1240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths()); decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] = auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q); slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding( ck_tile::reference_batched_rotary_position_embedding(
...@@ -1202,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1202,13 +1256,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
k_host_ref.ForEach([&](auto& self, auto i) { k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]); self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
}); });
} else { } else {
k_host_ref.ForEach([&](auto& self, auto i) { k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
}); });
} }
} else } else
#endif #endif
{ {
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); }); if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
...@@ -1229,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1229,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
knew_host_ref_ro.emplace(knew_host_ref.get_lengths()); knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] = auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew); slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
ck_tile::reference_batched_rotary_position_embedding( ck_tile::reference_batched_rotary_position_embedding(
...@@ -1251,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1251,19 +1305,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < page_block_size) { if(0 < page_block_size) {
if(is_v_rowmajor) { if(is_v_rowmajor) {
if(i_perm) { if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
}); });
} else { } else {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
}); });
} }
} }
else else
{ {
if(i_perm) { if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) { v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
}); });
} else { } else {
...@@ -1458,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -1458,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on // clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method); auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err( bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass; pass &= cur_pass;
...@@ -1515,15 +1569,15 @@ int main(int argc, char* argv[]) ...@@ -1515,15 +1569,15 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec"); const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16") if(data_type == "fp16")
{ {
return run<ck_tile::half_t>(arg_parser) ? 0 : -2; return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "bf16") else if(data_type == "bf16")
{ {
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2; return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
} }
else if(data_type == "fp8") else if(data_type == "fp8")
{ {
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2; return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
} }
return -3; return -3;
......
...@@ -16,11 +16,35 @@ ...@@ -16,11 +16,35 @@
#include <utility> #include <utility>
#include <variant> #include <variant>
struct FmhaFwdFp16
{
};
struct FmhaFwdBf16
{
};
struct FmhaFwdFp8
{
};
struct FmhaFwdBf8
{
};
struct FmhaFwdFp8Fp16
{
};
struct FmhaFwdFp8Bf16
{
};
template <typename DataType> template <typename DataType>
struct FmhaFwdTypeConfig; struct FmhaFwdTypeConfig;
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::half_t> struct FmhaFwdTypeConfig<FmhaFwdFp16>
{ {
using QDataType = ck_tile::half_t; using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t; using KDataType = ck_tile::half_t;
...@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t> ...@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf16_t> struct FmhaFwdTypeConfig<FmhaFwdBf16>
{ {
using QDataType = ck_tile::bf16_t; using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t;
...@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t> ...@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::fp8_t> struct FmhaFwdTypeConfig<FmhaFwdFp8>
{ {
using QDataType = ck_tile::fp8_t; using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t;
...@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t> ...@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t> struct FmhaFwdTypeConfig<FmhaFwdBf8>
{ {
using QDataType = ck_tile::bf8_t; using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment