Commit 1d784873 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents d25889b1 851c3ed1
* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
# Documentation files
- docs/* @ROCm/rocm-documentation
+ docs/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
- *.md @ROCm/rocm-documentation
+ *.md @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
- *.rst @ROCm/rocm-documentation
+ *.rst @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
.readthedocs.yaml @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
# Header directory for Doxygen documentation
- library/include/* @ROCm/rocm-documentation
+ library/include/* @ROCm/rocm-documentation @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex
@@ -15,4 +15,4 @@ python:
build:
os: ubuntu-22.04
tools:
- python: "3.8"
+ python: "3.10"
@@ -660,7 +660,8 @@ def process_results(Map conf=[:]){
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION=
0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
- 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false''' : ""
+ 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
pipeline {
agent none
@@ -727,6 +728,10 @@ pipeline {
name: "RUN_CODEGEN_TESTS",
defaultValue: true,
description: "Run the codegen tests (default: ON)")
booleanParam(
name: "BUILD_INSTANCES_ONLY",
defaultValue: false,
description: "Test building instances for various architectures simultaneously (default: OFF)")
}
environment{
dbuser = "${dbuser}"
@@ -809,22 +814,22 @@ pipeline {
{
parallel
{
- stage("Run Codegen Tests on MI100/MI200")
+ stage("Run Codegen Tests on MI200")
{
when {
beforeAgent true
expression { params.RUN_CODEGEN_TESTS.toBoolean() }
}
options { retry(2) }
- agent{ label rocmnode("gfx908 || gfx90a")}
+ agent{ label rocmnode("gfx90a")}
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \
cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
-D CMAKE_BUILD_TYPE=Release \
- -D GPU_TARGETS="gfx908;gfx90a" \
+ -D GPU_TARGETS="gfx90a" \
- -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j check"""
+ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check"""
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
@@ -837,23 +842,23 @@ pipeline {
{
parallel
{
- stage("Build CK and run Tests on MI100/MI200/MI300")
+ stage("Build CK for all gfx9 targets")
{
when {
beforeAgent true
expression { params.RUN_FULL_QA.toBoolean() }
}
- agent{ label rocmnode("gfx908 || gfx90a") }
+ agent{ label rocmnode("gfx90a") }
environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
-DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \
- -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
+ -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
- -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
+ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
@@ -868,43 +873,63 @@ pipeline {
}
agent{ label rocmnode("gfx942") }
environment{
- setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
+ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
- -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
+ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
- stage("Build CK and run Tests on MI100/MI200")
+ stage("Build CK and run Tests on MI200")
{
when {
beforeAgent true
- expression { !params.RUN_FULL_QA.toBoolean() }
+ expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
}
- agent{ label rocmnode("gfx908 || gfx90a") }
+ agent{ label rocmnode("gfx90a") }
environment{
- setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """
+ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908;gfx90a" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \
- -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """
+ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
}
steps{
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
cleanWs()
}
}
stage("Build CK instances for different targets")
{
when {
beforeAgent true
expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() }
}
agent{ label rocmnode("gfx90a") }
environment{
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \
-D INSTANCES_ONLY=ON \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j32 """
}
steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
cleanWs()
}
}
stage("Build CK and run Tests on Navi21") stage("Build CK and run Tests on Navi21")
{ {
when { when {
beforeAgent true beforeAgent true
expression { !params.RUN_FULL_QA.toBoolean() } expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
} }
agent{ label rocmnode("navi21") } agent{ label rocmnode("navi21") }
environment{ environment{
@@ -924,7 +949,7 @@ pipeline {
{
when {
beforeAgent true
- expression { !params.RUN_FULL_QA.toBoolean() }
+ expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() }
}
agent{ label rocmnode("navi32") }
environment{
@@ -947,27 +972,11 @@ pipeline {
{
parallel
{
- stage("Run ckProfiler: gfx90*")
- {
- when {
- beforeAgent true
- expression { !params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
- }
- options { retry(2) }
- agent{ label rocmnode("gfx908 || gfx90a")}
- environment{
- setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """
- }
- steps{
- runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
- cleanWs()
- }
- }
stage("Run ckProfiler: gfx90a")
{
when {
beforeAgent true
- expression { params.RUN_FULL_QA.toBoolean() && params.RUN_PERFORMANCE_TESTS.toBoolean() }
+ expression { params.RUN_PERFORMANCE_TESTS.toBoolean() }
}
options { retry(2) }
agent{ label rocmnode("gfx90a")}
...
- rocm-docs-core==0.38.1
+ rocm-docs-core==1.1.1
sphinxcontrib-bibtex==2.6.2
#
- # This file is autogenerated by pip-compile with Python 3.8
+ # This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile requirements.in
@@ -48,12 +48,6 @@ idna==3.4
# via requests
imagesize==1.4.1
# via sphinx
- importlib-metadata==6.8.0
- # via
- # sphinx
- # sphinxcontrib-bibtex
- importlib-resources==6.1.0
- # via rocm-docs-core
jinja2==3.1.2
# via
# myst-parser
@@ -99,8 +93,6 @@ pyjwt[crypto]==2.6.0
# via pygithub
pynacl==1.5.0
# via pygithub
- pytz==2023.3.post1
- # via babel
pyyaml==6.0
# via
# myst-parser
@@ -111,7 +103,7 @@ requests==2.31.0
# via
# pygithub
# sphinx
- rocm-docs-core==0.38.1
+ rocm-docs-core==1.1.1
# via -r requirements.in
six==1.16.0
# via
@@ -165,7 +157,3 @@ urllib3==1.26.18
# via requests
wrapt==1.15.0
# via deprecated
- zipp==3.17.0
- # via
- # importlib-metadata
- # importlib-resources
add_custom_target(example_grouped_gemm_xdl_multi_abd)
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16 grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16.cpp)
- add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
+ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8 grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8.cpp)
- add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
+ add_example_dependencies(example_grouped_gemm_xdl_multi_abd example_grouped_gemm_multi_abd_xdl_fixed_nk_bias_bf16_i8)
@@ -17,7 +17,7 @@ add_custom_command(
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
- message("adding tile_example ${EXAMPLE_NAME}")
+ message("adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
...
@@ -30,27 +30,29 @@ args:
-mode kernel mode. 0:batch, 1:group (default:0)
-b batch size (default:2)
-h num of head, for q (default:8)
- -h_k num of head, for k/v, 0 means equal to h (default:0)
+ -h_k num of head, for k/v, -1 means equal to h (default:-1)
if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
- -s_k seqlen_k, 0 means equal to s (default:0)
+ -s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128)
- -d_v head dim for v, 0 means equal to d (default:0)
+ -d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
note when squant=1, this value will be modified by range_q/k
- -range_q per-tensor quantization range of q. used if squant=1. (default:2)
+ -range_q per-tensor quantization range of q. used if squant=1. (default:16)
- -range_k per-tensor quantization range of k. used if squant=1. (default:2)
+ -range_k per-tensor quantization range of k. used if squant=1. (default:16)
- -range_v per-tensor quantization range of v. used if squant=1. (default:2)
+ -range_v per-tensor quantization range of v. used if squant=1. (default:16)
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
- -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2)
+ -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
-squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0)
1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
scale_o according to range_q, range_k, range_v, range_p, range_o
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
- -bias add bias or not (default:0)
+ -bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
@@ -59,11 +61,11 @@ args:
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
-lse 0 not store lse, 1 store lse (default:0)
-kname if set to 1 will print kernel name (default:0)
-init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
@@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov
### attention bias
Attention bias is supported with the layout of `1*1*s*s` (similar to input/output; a different layout can be supported by changing the stride value for bias, or even extended to `b*h*s*s`), with bias values in float.
### alibi
ALiBi (attention with linear biases) is supported. Select it with `-bias=a` for one slope per head (layout `1*h`), or `-bias=a:1` for per-batch, per-head slopes (layout `b*h`).
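For reference, below is a minimal sketch of the conventional ALiBi slope sequence. This is an assumption based on the usual ALiBi formulation, not code from this repository; the example itself obtains its slopes from `ck_tile::get_alibi_slopes(nhead)`:

```
// Hedged sketch: conventional ALiBi slopes, assuming nhead is a power of two.
// The actual values used by tile_example_fmha_fwd come from ck_tile::get_alibi_slopes().
#include <cmath>
#include <vector>

std::vector<float> alibi_slopes_sketch(int nhead)
{
    std::vector<float> slopes(nhead);
    const float base = std::pow(2.0f, -8.0f / static_cast<float>(nhead)); // 2^(-8/n)
    for(int i = 0; i < nhead; ++i)
        slopes[i] = std::pow(base, static_cast<float>(i + 1)); // head i gets slope 2^(-8*(i+1)/n)
    return slopes;
}
```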
### lse
For training kernels, the "log sum exp" needs to be stored in the forward pass and reused in the backward pass. We support this by setting `-lse=1`.
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};
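A brief usage sketch of the `bias_info::decode()` helper above (illustrative only, e.g. as it would be called inside the example's `run()`; the strings mirror the `-bias` values documented in the example README):

```
// Illustrative only: how the "-bias" command-line strings map to bias_info.
bias_info b_none  = bias_info::decode("n");   // no_bias
bias_info b_elem  = bias_info::decode("e:2"); // elementwise_bias, rank_info=2 -> b*h*s*s bias
bias_info b_alibi = bias_info::decode("a:1"); // alibi, rank_info=1 -> one slope per batch*head
```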
@@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[])
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
- "0",
+ "-1",
- "num of head, for k/v, 0 means equal to h\n"
+ "num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
- .insert("s_k", "0", "seqlen_k, 0 means equal to s")
+ .insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
- .insert("d_v", "0", "head dim for v, 0 means equal to d")
+ .insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale_s",
"0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
@@ -71,7 +71,11 @@ auto create_args(int argc, char* argv[])
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
- .insert("bias", "0", "add bias or not")
+ .insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
@@ -153,7 +157,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
- if(nhead_k == 0)
+ if(nhead_k < 0)
nhead_k = nhead;
if(nhead % nhead_k != 0)
@@ -164,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
- if(seqlen_k == 0)
+ if(seqlen_k < 0)
seqlen_k = seqlen_q;
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
- if(hdim_v == 0)
+ if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
@@ -208,9 +212,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::string vlayout = arg_parser.get_str("vlayout");
- bool use_bias = arg_parser.get_bool("bias");
bool lse = arg_parser.get_bool("lse");
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
int init_method = arg_parser.get_int("init");
@@ -295,12 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<VDataType> v_host(
is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k));
// use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host
// will not be used for verification at all (but will be copied to device anyway).
ck_tile::HostTensor<BiasDataType> bias_host(
- use_bias
+ bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<SaccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
ck_tile::HostTensor<LSEDataType> lse_host(
lse ? std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q}
@@ -341,6 +351,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
assert(slopes.size() == nhead);
if(bias.rank_info == 0)
{
// alibi in 1*h
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
}
else
{
// alibi in b*h
for(auto i_b = 0; i_b < batch; i_b++)
{
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
}
}
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
@@ -350,6 +378,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
@@ -357,6 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off
auto layout_str = [&](bool permute){
@@ -372,9 +402,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
- << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s
+ << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
- << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant
+ << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
- << ", mask:" << mask << ", v:" << vlayout << std::flush;
+ << std::flush;
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
@@ -382,7 +412,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
- use_bias,
+ bias.type,
lse,
squant};
@@ -441,7 +471,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
- bias_buf.GetDeviceBuffer(),
+ bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
@@ -461,7 +492,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_q,
stride_k,
stride_v,
- stride_bias,
+ bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
@@ -564,8 +596,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
- if(use_bias)
+ if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
@@ -582,6 +615,51 @@ bool run(const ck_tile::ArgParser& arg_parser)
SMPLComputeDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
else if(bias.type == bias_enum::alibi)
{
// alibi construct elementwise bias to verify
auto alibi_host = [&]() {
if(mask.type != mask_enum::no_mask)
{
return ck_tile::make_alibi_from_lr_mask<SaccDataType, true>(
0,
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
}
else
{
return ck_tile::Alibi<SaccDataType, true>{
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
}
}();
ck_tile::HostTensor<SaccDataType> alibi_bias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
for(auto i_h = 0; i_h < nhead; i_h++)
{
SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
alibi_host.slope = current_slope;
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
SaccDataType pixel = 0;
alibi_host.update(pixel, i_r, i_c);
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
}
}
}
// [nhead, real_seqlen_q, real_seqlen_k]
ck_tile::reference_batched_elementwise<SMPLComputeDataType,
SaccDataType,
SMPLComputeDataType,
SMPLComputeDataType>(
s_host_ref, alibi_bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
...
@@ -8,6 +8,7 @@
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
template <typename DataType>
@@ -86,7 +87,7 @@ struct fmha_fwd_args
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
- const void* bias_ptr;
+ const void* bias_ptr; // bias or alibi_slope pointer
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
@@ -106,7 +107,7 @@ struct fmha_fwd_args
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
- ck_tile::index_t stride_bias;
+ ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
@@ -219,7 +220,7 @@ template <ck_tile::index_t HDim_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
- bool kHasBias_,
+ ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kDoFp8StaticQuant_,
bool kPadS_,
@@ -240,7 +241,7 @@ struct fmha_fwd_traits_
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
- static constexpr bool kHasBias = kHasBias_;
+ static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
@@ -261,7 +262,7 @@ struct fmha_fwd_traits
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
- bool has_bias;
+ bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
...
@@ -40,6 +40,19 @@ MASK_MAP = {
"generic" : "FmhaMasks::GenericMask"
}
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
@@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = {
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
- FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
+ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
@@ -213,7 +226,7 @@ class FmhaFwdApiTrait:
bk0blen : int
vlayout : str
mask : str
- bias : str # true/false
+ bias : str #
lse : str #
squant : str #
spad : str
@@ -241,8 +254,8 @@ class FmhaFwdApiTrait:
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
- if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
+ if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
- else : return f'a.seqlen_k % {self.bn0} == 0'
+ else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
@@ -297,7 +310,7 @@ class FmhaFwdPipeline:
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
- if self.F_bias == 't' : n += '_bias'
+ if self.F_bias != 'no' : n += f'_{self.F_bias}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
@@ -332,7 +345,8 @@ class FmhaFwdApiPool:
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
- F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
+ F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
@@ -400,7 +414,7 @@ class FmhaFwdKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
- F_bias = BOOL_MAP[self.F_pipeline.F_bias],
+ F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
@@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
- '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1)
+ '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
}
else:
return None
@@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
- for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
+ for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
@@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
- for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
+ for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask))
else:
assert False
...
// SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -149,11 +149,9 @@ struct mask_info
return tmp;
}
- friend std::ostream& operator<<(std::ostream& os, const mask_info& mi);
- };
- inline std::ostream& operator<<(std::ostream& os, const mask_info& mi)
- {
+ friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
+ {
mi.serialize(os);
return os;
}
+ };
@@ -17,7 +17,7 @@ for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
- for bias in 0 1 ; do
+ for bias in "n" "e" "a"; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
@@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done
@@ -37,9 +38,11 @@ done
done
for perm in 0 1 ; do
- for bias in 0 1 ; do
+ for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done
done
done
done
@@ -236,6 +236,9 @@
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
// set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1
namespace ck {
enum struct InMemoryDataOperationEnum
...
@@ -5,6 +5,7 @@
#include <hip/hip_runtime.h>
#include <set>
#include <vector>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
...
@@ -154,3 +154,8 @@
#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif
// TODO: better solve this inside compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif
@@ -536,4 +536,15 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
// TODO: this is hacky, we use u16
return __builtin_amdgcn_sad_u16(x, y, acc);
}
CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{
return (x > y ? (x - y) : (y - x)) + acc;
}
} // namespace ck_tile
@@ -3,7 +3,9 @@
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockAttentionBiasEnum
{
NO_BIAS = 0,
ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale)
ALIBI = 2, // bias computed with position encoding, applied after scale
};
template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::NO_BIAS>
{
static constexpr const char* name = "";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
static constexpr const char* name = "bias";
};
template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
static constexpr const char* name = "alibi";
};
} // namespace ck_tile
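A small sketch (not part of the commit) of how the specializations above can be read at compile time; the name fragments feed the generated kernel names, matching the `_bias` / `_alibi` suffixes produced by the codegen's `FmhaFwdPipeline.name()`:

```
// Illustrative only: reading the name fragments that feed generated kernel names.
constexpr const char* elem_tag =
    ck_tile::BlockAttentionBiasEnumToStr<ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS>::name; // "bias"
constexpr const char* alibi_tag =
    ck_tile::BlockAttentionBiasEnumToStr<ck_tile::BlockAttentionBiasEnum::ALIBI>::name;            // "alibi"
```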