"vscode:/vscode.git/clone" did not exist on "01a5d18a537b65a156cfa1a77706693a24c869c1"
Commit f3ff55b6 authored by illsilin

merge from public repo

parents 5e93fa9e 689a5ae4
We'd love for you to contribute to our source code!
Some helpful links:
- [Code of Conduct guidelines](https://www.contributor-covenant.org/version/2/1/code_of_conduct/code_of_conduct.txt)
- [New issue guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/ISSUE_TEMPLATE.md)
- [Submitting a pull request guidelines](https://github.com/rocm/composable_kernel/blob/develop/.github/PULL_REQUEST_TEMPLATE.md)
- [Maintainers](https://github.com/rocm/composable_kernel/blob/develop/CONTRIBUTORS.md)
- [General information](https://github.com/rocm/composable_kernel/blob/develop/README.md)
- [ROCm documentation](https://rocm.docs.amd.com/en/latest/how-to/llm-fine-tuning-optimization/optimizing-with-composable-kernel.html)
When creating an issue, please check if a similar issue already exists.
### When reporting a bug, please include:
- [ ] A descriptive title
- [ ] An isolated way to reproduce the behavior (preferably a docker container with a repro)
- [ ] ROCm version, clang version, Composable Kernel commit pin
- [ ] Environment variables
- [ ] The behavior you expect to see, and the behavior you actually see
### When requesting a feature, please include:
- [ ] A descriptive title
- [ ] A detailed description of the problem you are trying to solve
- [ ] An overview of the suggested solution
- [ ] Explanation why the solution is an improvement
## Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
## Checklist
Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
- [ ] I have added tests relevant to the introduced functionality, and the unit tests are passing locally
- [ ] I have added inline documentation that helps the maintainers understand the motivation behind the change
- [ ] I have removed stale documentation that is no longer relevant after this pull request
- [ ] (If this change is user-facing) I have added release notes that give end users a brief summary of the improvement
- [ ] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged
## Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.
@@ -183,14 +183,17 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9")
message("Enabling XDL instances")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94")
message("Enabling FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message("Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
add_definitions(-DCK_USE_OCP_FP8)
@@ -204,6 +207,7 @@ endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
set(CK_USE_FP8_ON_UNSUPPORTED_ARCH "ON")
endif()
# CK config file to record supported datatypes, etc.
......
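The `add_definitions(...)` calls above inject `-DCK_USE_XDL`, `-DCK_USE_GFX94`, `-DCK_USE_WMMA`, and friends into every translation unit, while the new `set(CK_USE_... "ON")` lines expose the same switches as CMake variables for the config file mentioned above. As a minimal, hedged sketch of how such compile definitions are typically consumed on the C++ side (the messages and structure here are illustrative; CK's real gating lives in its instance factories):

```cpp
#include <cstdio>

// Illustrative consumer of the compile definitions added above.
// Build with e.g. -DCK_USE_XDL to light up the matching branch.
int main()
{
#ifdef CK_USE_XDL
    std::puts("XDL (matrix-core) instances enabled for gfx9 targets");
#endif
#ifdef CK_USE_WMMA
    std::puts("WMMA instances enabled for gfx11/gfx12 targets");
#endif
#ifdef CK_USE_FP8_ON_UNSUPPORTED_ARCH
    std::puts("FP8 GEMM instances forced on gfx908/gfx90a");
#endif
    return 0;
}
```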
FROM ubuntu:20.04
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=6.3
ARG compiler_version=""
@@ -48,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev \
libpthread-stubs0-dev \
llvm-amdgpu \
mpich \
net-tools \
pkg-config \
python \
@@ -63,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \
zlib1g-dev \
zip \
libzstd-dev \
openssh-server \
clang-format-12 \
kmod && \
@@ -70,7 +72,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rm -rf /var/lib/apt/lists/* && \
rm -rf amdgpu-install* && \
# Remove unnecessary rocm components that take a lot of space
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt
# Update the cmake to version 3.27.5
RUN pip install --upgrade cmake==3.27.5 && \
@@ -92,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results
pip3 install --upgrade pip && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.0.3 setuptools-rust sshtunnel==0.4.0 && \
pip3 install sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version
......
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub20.04_rocm6.3"
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""
......
@@ -40,10 +40,10 @@ def getBaseDockerImageName(){
else{
def ROCM_numeric = "${params.ROCMVERSION}" as float
if ( ROCM_numeric < 6.4 ){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}"
}
else{
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}"
}
}
return img
@@ -357,7 +357,7 @@ def buildHipClangJob(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
@@ -426,7 +426,7 @@ def Build_CK(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
@@ -566,11 +566,9 @@ def Build_CK(Map conf=[:]){
ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install"
cmake --build build -- -j
ctest --test-dir build
"""
}
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
}
}
}
@@ -752,8 +750,8 @@ pipeline {
description: "Test building instances for various architectures simultaneously (default: OFF)")
booleanParam(
name: "BUILD_GFX12",
defaultValue: false,
description: "Build CK and run tests on gfx12 (default: OFF)")
defaultValue: true,
description: "Build CK and run tests on gfx12 (default: ON)")
booleanParam(
name: "NINJA_BUILD_TRACE",
defaultValue: false,
......
@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <stdexcept>
namespace rtc {
......
rocm-docs-core==1.11.0
rocm-docs-core==1.12.0
sphinxcontrib-bibtex==2.6.3
@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.11.0
rocm-docs-core==1.12.0
# via -r requirements.in
six==1.16.0
# via pybtex
......
@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
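For context on the `ABlockLdsExtraM`/`BBlockLdsExtraN` flips from 1 to 0 above: in CK these knobs pad the LDS tile by an extra element along M/N to reduce bank conflicts, at the cost of LDS capacity. A rough, self-contained sketch under that assumption (the tile sizes and footprint formula are illustrative, not CK's exact internal computation):

```cpp
// Illustrative only: how an "extra M" pad grows the LDS footprint.
// Real CK kernels derive this layout inside their device templates.
constexpr int KPerBlock       = 32;
constexpr int MPerBlock      = 256;
constexpr int ABlockLdsExtraM = 0; // this diff flips the pad from 1 to 0

// Padded LDS tile: KPerBlock rows of (MPerBlock + pad) elements.
constexpr int lds_elements = KPerBlock * (MPerBlock + ABlockLdsExtraM);
static_assert(lds_elements == 32 * 256, "pad disabled: exactly K*M elements");

int main() { return 0; }
```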
@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
MASK_IMPL = {
......
@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaBwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode],
@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad],
@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
......
@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -339,7 +339,7 @@ class FmhaFwdKernel:
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
pass
else:
assert False
return pipelines
@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
......
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
pass
else:
assert False
return pipelines
@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
......
@@ -435,7 +435,7 @@ class FmhaFwdSplitKVApiPool:
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -472,7 +472,7 @@ class FmhaFwdSplitKVKernel:
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
@@ -552,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
@@ -644,6 +644,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']:
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
pass
else:
assert False
return pipelines
@@ -651,7 +654,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
@@ -711,7 +714,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None:
continue
......
@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
}
// different threshold for different dtype
template <typename DataType>
template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-2;
@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType>
template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
using TypeConfig = FmhaBwdTypeConfig<DataType>;
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
}
return -3;
......
@@ -14,11 +14,19 @@
#include <utility>
#include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<ck_tile::half_t>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
......
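The header change above swaps raw element types (`ck_tile::half_t`) for empty tag structs (`FmhaBwdFp16`) as the keys of `FmhaBwdTypeConfig`, so a single key can later fan out to a mixed bundle of Q/K/V/O element types. A self-contained miniature of the idiom, with mock names standing in for the real ones:

```cpp
#include <type_traits>

// Miniature of the tag-dispatch idiom above; HalfMock stands in for
// ck_tile::half_t, and Fp16Tag plays the role of FmhaBwdFp16.
struct Fp16Tag {};
struct HalfMock {};

template <typename Tag>
struct TypeConfig; // primary template: one specialization per tag

template <>
struct TypeConfig<Fp16Tag>
{
    using QDataType = HalfMock;
    using KDataType = HalfMock; // a real config lists every element type
};

// Kernels key on the tag and pull element types out of the config:
template <typename Tag>
bool run()
{
    using Q = typename TypeConfig<Tag>::QDataType;
    return std::is_same_v<Q, HalfMock>;
}

int main() { return run<Fp16Tag>() ? 0 : 1; }
```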
@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
@@ -41,7 +42,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
@@ -142,7 +143,7 @@ auto create_args(int argc, char* argv[])
}
// different threshold for different dtype
template <typename DataType>
template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
@@ -151,7 +152,7 @@ auto get_elimit(std::string /*init_method*/)
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
@@ -159,7 +160,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
auto get_elimit<FmhaFwdFp8>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
@@ -261,7 +262,7 @@ int override_num_splits_if_necessary(
return num_splits;
}
template <typename DataType>
template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
@@ -305,8 +306,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
{
if(0 < rotary_dim)
{
@@ -428,25 +429,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
return atoi(squant_str.c_str()) != 0 ? true : false;
}();
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
scale_p = dtype_max / range_p;
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
scale_o = range_p * range_v / range_o / dtype_max;
}
std::string vlayout = arg_parser.get_str("vlayout");
bool lse = arg_parser.get_bool("lse");
@@ -466,7 +448,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
bool s_randval = false;
if(p_drop > 0.0f && do_validation)
if(p_drop > 0.0f && do_validation != 0)
{
s_randval = true;
}
@@ -499,7 +481,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
using TypeConfig = FmhaFwdTypeConfig<DataType>;
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
@@ -513,6 +495,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v are the same type
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
scale_p = p_dtype_max / range_p;
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
}
// accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
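The relocated quantization block above now reads each tensor's own numeric maximum (`q_dtype_max`, `k_dtype_max`, and so on), which is possible because it sits after the `TypeConfig` aliases; the deleted version used one shared `dtype_max`, which is only correct when Q, K, V, P, and O share a single type. Writing $\max_X$ for a type's representable maximum and $\mathrm{range}_x$ for the expected dynamic range, the scale algebra the block implements is:

$$
\mathrm{scale}_s \mathrel{*}= \frac{\mathrm{range}_q}{\max_Q}\cdot\frac{\mathrm{range}_k}{\max_K},
\qquad
\mathrm{scale}_p = \frac{\max_P}{\mathrm{range}_p},
\qquad
\mathrm{scale}_o = \frac{\max_O}{\mathrm{range}_o}\cdot\frac{\mathrm{range}_p}{\max_P}\cdot\frac{\mathrm{range}_v}{\max_V}
$$

Each $\mathrm{range}/\max$ factor undoes the quantization of an input (a tensor $x$ is stored as $x\cdot\max_X/\mathrm{range}_x$), and each $\max/\mathrm{range}$ factor re-quantizes a result into its output type; when all tensors share one type, these expressions collapse back to the deleted single-`dtype_max` forms.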
@@ -709,14 +713,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization
{
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
@@ -1118,25 +1122,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush;
if(!do_validation)
if(do_validation == 0)
{
std::cout << std::flush << std::endl;
return true;
}
if(do_validation == 2)
{
// NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
ck_tile::naive_attention_fwd_args naive_a;
naive_a.q_ptr = q_buf.GetDeviceBuffer();
naive_a.k_ptr = k_buf.GetDeviceBuffer();
naive_a.v_ptr = v_buf.GetDeviceBuffer();
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
naive_a.scale_s = scale_s;
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
naive_a.page_table_ptr =
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
naive_a.hdim = hdim_q;
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
naive_a.batch_q = batch;
naive_a.batch_kv = batch;
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
naive_a.seqlen_q = seqlen_qs[0];
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
naive_a.nhead_q = nhead;
naive_a.nhead_kv = nhead_k;
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
ck_tile::stream_config naive_s{};
naive_attention_fwd(naive_t, naive_a, naive_s);
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
return pass_;
}
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
@@ -1458,7 +1512,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass;
@@ -1515,15 +1569,15 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
}
return -3;
......
@@ -16,11 +16,35 @@
#include <utility>
#include <variant>
struct FmhaFwdFp16
{
};
struct FmhaFwdBf16
{
};
struct FmhaFwdFp8
{
};
struct FmhaFwdBf8
{
};
struct FmhaFwdFp8Fp16
{
};
struct FmhaFwdFp8Bf16
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<ck_tile::half_t>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
struct FmhaFwdTypeConfig<FmhaFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::fp8_t>
struct FmhaFwdTypeConfig<FmhaFwdFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
struct FmhaFwdTypeConfig<FmhaFwdBf8>
{
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
......
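`FmhaFwdFp8Fp16` and `FmhaFwdFp8Bf16` receive tag structs above but no `FmhaFwdTypeConfig` specializations yet, matching the TODO stubs in the generator scripts. Purely as a hypothetical illustration of what the tag design enables (this specialization is an assumption, not part of the commit), a mixed config could pair fp8 inputs with half-precision output:

```cpp
// HYPOTHETICAL: not in this commit. A possible mixed-precision config,
// shaped like the specializations above: fp8 Q/K/V with fp16 output.
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8Fp16>
{
    using QDataType = ck_tile::fp8_t;
    using KDataType = ck_tile::fp8_t;
    using VDataType = ck_tile::fp8_t;
    using ODataType = ck_tile::half_t;
};
```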