Commit 401e643e authored by Po Yen Chen's avatar Po Yen Chen
Browse files

Merge branch 'develop' into feature/use-larger-tile-size-for-chunk-prefill

parents d783a8cf fdfe2102
FROM ubuntu:20.04
FROM ubuntu:22.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=6.2
ARG ROCMVERSION=6.3
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
......@@ -13,17 +13,12 @@ RUN set -xe && \
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \
curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
RUN if [ "$ROCMVERSION" != "6.3" ]; then \
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.2.60200-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.2.60200-1_all.deb && \
RUN if [ "$ROCMVERSION" != "6.4" ]; then \
sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \
apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \
amdgpu-repo --amdgpu-build=2074281; \
fi
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" && \
......@@ -53,6 +48,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
libnuma-dev \
libpthread-stubs0-dev \
llvm-amdgpu \
mpich \
net-tools \
pkg-config \
python \
......@@ -68,6 +64,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
nano \
zlib1g-dev \
zip \
libzstd-dev \
openssh-server \
clang-format-12 \
kmod && \
......@@ -75,7 +72,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
rm -rf /var/lib/apt/lists/* && \
rm -rf amdgpu-install* && \
# Remove unnecessary rocm components that take a lot of space
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev
apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt
# Update the cmake to version 3.27.5
RUN pip install --upgrade cmake==3.27.5 && \
......@@ -97,7 +94,7 @@ RUN pip install --upgrade cmake==3.27.5 && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results
pip3 install --upgrade pip && \
pip3 install sqlalchemy==1.4.46 pymysql pandas==2.0.3 setuptools-rust sshtunnel==0.4.0 && \
pip3 install sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version
......
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub20.04_rocm6.2"
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""
......
......@@ -38,13 +38,14 @@ def getBaseDockerImageName(){
img = "${params.USE_CUSTOM_DOCKER}"
}
else{
if (params.ROCMVERSION != "6.3"){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
}
else{
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
def ROCM_numeric = "${params.ROCMVERSION}" as float
if ( ROCM_numeric < 6.4 ){
img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}"
}
else{
img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}"
}
}
}
return img
}
......@@ -329,10 +330,8 @@ def cmake_build(Map conf=[:]){
try{
archiveArtifacts "perf_fmha_fwd_*.log"
archiveArtifacts "perf_fmha_bwd_*.log"
stash name: "perf_fmha_fwd_gfx942.log"
stash name: "perf_fmha_bwd_gfx942.log"
stash name: "perf_fmha_fwd_gfx90a.log"
stash name: "perf_fmha_bwd_gfx90a.log"
stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942"
stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a"
}
catch(Exception err){
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
......@@ -358,7 +357,7 @@ def buildHipClangJob(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -378,7 +377,7 @@ def buildHipClangJob(Map conf=[:]){
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 48, unit: 'HOURS')
timeout(time: 20, unit: 'HOURS')
{
cmake_build(conf)
}
......@@ -407,128 +406,6 @@ def buildHipClangJobAndReboot(Map conf=[:]){
}
}
def runCKProfiler(Map conf=[:]){
show_node_info()
env.HSA_ENABLE_SDMA=0
checkout scm
def image = getDockerImageName()
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
echo "Docker flags: ${dockerOpts}"
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES'){
sh 'rocminfo | tee rocminfo.log'
if ( !runShell('grep -n "gfx" rocminfo.log') ){
throw new Exception ("GPU not found")
}
else{
echo "GPU is OK"
}
}
}
}
catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){
echo "The job was cancelled or aborted"
throw e
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 24, unit: 'HOURS')
{
sh """
rm -rf build
mkdir build
"""
dir("build"){
unstash 'ckProfiler.tar.gz'
sh 'tar -xvf ckProfiler.tar.gz'
}
dir("script"){
if (params.RUN_FULL_QA){
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
archiveArtifacts "perf_batched_gemm.log"
archiveArtifacts "perf_grouped_gemm.log"
archiveArtifacts "perf_grouped_conv_fwd.log"
archiveArtifacts "perf_grouped_conv_bwd_data.log"
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
archiveArtifacts "perf_gemm_bilinear.log"
archiveArtifacts "perf_reduction.log"
archiveArtifacts "perf_splitK_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_mixed_gemm.log"
// stash perf files to master
stash name: "perf_gemm.log"
stash name: "perf_resnet50_N256.log"
stash name: "perf_resnet50_N4.log"
stash name: "perf_batched_gemm.log"
stash name: "perf_grouped_gemm.log"
stash name: "perf_grouped_conv_fwd.log"
stash name: "perf_grouped_conv_bwd_data.log"
stash name: "perf_grouped_conv_bwd_weight.log"
stash name: "perf_gemm_bilinear.log"
stash name: "perf_reduction.log"
stash name: "perf_splitK_gemm.log"
stash name: "perf_onnx_gemm.log"
stash name: "perf_mixed_gemm.log"
//we will process results on the master node
}
else{
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
// stash perf files to master
stash name: "perf_gemm.log"
stash name: "perf_resnet50_N256.log"
stash name: "perf_resnet50_N4.log"
//we will process the results on the master node
}
}
}
}
}
return retimage
}
def runPerfTest(Map conf=[:]){
try{
runCKProfiler(conf)
}
catch(e){
echo "throwing error exception in performance tests"
echo 'Exception occurred: ' + e.toString()
throw e
}
finally{
if (!conf.get("no_reboot", false)) {
reboot()
}
}
}
def Build_CK(Map conf=[:]){
show_node_info()
......@@ -549,7 +426,7 @@ def Build_CK(Map conf=[:]){
def prefixpath = conf.get("prefixpath", "/opt/rocm")
// Jenkins is complaining about the render group
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
if (conf.get("enforce_xnack_on", false)) {
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
}
......@@ -572,7 +449,7 @@ def Build_CK(Map conf=[:]){
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 5, unit: 'MINUTES'){
timeout(time: 2, unit: 'MINUTES'){
sh 'rocminfo | tee rocminfo.log'
if ( !runShell('grep -n "gfx" rocminfo.log') ){
throw new Exception ("GPU not found")
......@@ -588,36 +465,95 @@ def Build_CK(Map conf=[:]){
throw e
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 24, unit: 'HOURS')
timeout(time: 20, unit: 'HOURS')
{
//check whether to run performance tests on this node
def do_perf_tests = 0
def arch_type = 0
sh 'rocminfo | tee rocminfo.log'
if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') || runShell('grep -n "gfx1201" rocminfo.log') || runShell('grep -n "gfx942" rocminfo.log') ){
do_perf_tests = 1
echo "Stash profiler and run performance tests"
if ( runShell('grep -n "gfx90a" rocminfo.log') ){
arch_type = 1
}
else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
arch_type = 2
}
else if ( runShell('grep -n "gfx1030" rocminfo.log') ) {
arch_type = 3
}
else if ( runShell('grep -n "gfx1101" rocminfo.log') ) {
arch_type = 4
}
else if ( runShell('grep -n "gfx1201" rocminfo.log') ) {
arch_type = 5
}
cmake_build(conf)
dir("build"){
//run tests and examples
//sh 'make -j check'
if (params.RUN_PERFORMANCE_TESTS && do_perf_tests == 0 ){
//we only need the ckProfiler to run the performance tests, so we pack and stash it
//do not stash profiler on nodes where we don't need to run performance tests
sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler'
stash name: "ckProfiler.tar.gz"
}
if (params.RUN_FULL_QA && do_perf_tests == 0 ){
// build deb packages for all gfx9 targets and prepare to export
if (params.RUN_FULL_QA && arch_type == 1 ){
// build deb packages for all gfx9 targets on gfx90a system and prepare to export
echo "Build ckProfiler package"
sh 'make -j package'
archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb'
archiveArtifacts artifacts: 'composablekernel-tests_*.deb'
sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb'
stash name: "ckprofiler_0.2.0_amd64.deb"
stash includes: "ckprofiler_0.2.0_amd64.deb", name: "ckprofiler_0.2.0_amd64.deb"
}
}
// run performance tests, stash the logs, results will be processed on the master node
dir("script"){
if (params.RUN_PERFORMANCE_TESTS){
if (params.RUN_FULL_QA && arch_type == 1){
// run full tests on gfx90a
echo "Run full performance tests"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
archiveArtifacts "perf_batched_gemm.log"
archiveArtifacts "perf_grouped_gemm.log"
archiveArtifacts "perf_grouped_conv_fwd.log"
archiveArtifacts "perf_grouped_conv_bwd_data.log"
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
archiveArtifacts "perf_gemm_bilinear.log"
archiveArtifacts "perf_reduction.log"
archiveArtifacts "perf_splitK_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_mixed_gemm.log"
stash includes: "perf_**.log", name: "perf_log"
}
else if ( arch_type == 1 ){
// run standard tests on gfx90a
echo "Run performance tests"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
stash includes: "perf_**.log", name: "perf_log"
}
// disable performance tests on gfx1030 for now.
//else if ( arch_type == 3){
// run basic tests on gfx1030
// echo "Run gemm performance tests"
// sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10"
// archiveArtifacts "perf_onnx_gemm_gfx10.log"
// stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10"
//}
else if ( arch_type == 4){
// run basic tests on gfx11
echo "Run gemm performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11"
archiveArtifacts "perf_onnx_gemm_gfx11.log"
stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11"
}
else if ( arch_type == 5 ){
// run basic tests on gfx12
echo "Run gemm performance tests"
sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12"
archiveArtifacts "perf_onnx_gemm_gfx12.log"
stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12"
}
}
}
if (params.hipTensor_test && do_perf_tests == 0 ){
//build and test hipTensor
if (params.hipTensor_test && arch_type == 1 ){
// build and test hipTensor on gfx90a node
sh """#!/bin/bash
rm -rf "${params.hipTensor_branch}".zip
rm -rf hipTensor-"${params.hipTensor_branch}"
......@@ -630,11 +566,9 @@ def Build_CK(Map conf=[:]){
ls -ltr
CC=hipcc CXX=hipcc cmake -Bbuild . -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install"
cmake --build build -- -j
ctest --test-dir build
"""
}
dir("hipTensor-${params.hipTensor_branch}/build"){
sh 'ctest'
}
}
}
}
......@@ -684,15 +618,13 @@ def process_results(Map conf=[:]){
}
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 1, unit: 'HOURS'){
timeout(time: 15, unit: 'MINUTES'){
try{
dir("script"){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
unstash "perf_fmha_fwd_gfx942.log"
unstash "perf_fmha_bwd_gfx942.log"
unstash "perf_fmha_fwd_gfx90a.log"
unstash "perf_fmha_bwd_gfx90a.log"
unstash "perf_fmha_log_gfx942"
unstash "perf_fmha_log_gfx90a"
}
catch(Exception err){
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
......@@ -702,26 +634,26 @@ def process_results(Map conf=[:]){
// unstash perf files to master
unstash "ckprofiler_0.2.0_amd64.deb"
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
unstash "perf_gemm.log"
unstash "perf_resnet50_N256.log"
unstash "perf_resnet50_N4.log"
unstash "perf_batched_gemm.log"
unstash "perf_grouped_gemm.log"
unstash "perf_grouped_conv_fwd.log"
unstash "perf_grouped_conv_bwd_data.log"
unstash "perf_grouped_conv_bwd_weight.log"
unstash "perf_gemm_bilinear.log"
unstash "perf_reduction.log"
unstash "perf_splitK_gemm.log"
unstash "perf_onnx_gemm.log"
unstash "perf_mixed_gemm.log"
unstash "perf_log"
try{
unstash "perf_log_gfx11"
unstash "perf_log_gfx12"
}
catch(Exception err){
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
}
sh "./process_qa_data.sh"
}
else{
// unstash perf files to master
unstash "perf_gemm.log"
unstash "perf_resnet50_N256.log"
unstash "perf_resnet50_N4.log"
unstash "perf_log"
try{
unstash "perf_log_gfx11"
unstash "perf_log_gfx12"
}
catch(Exception err){
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
}
sh "./process_perf_data.sh"
}
}
......@@ -739,10 +671,10 @@ def process_results(Map conf=[:]){
}
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.2;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % ROCMVERSION=6.2;hipTensor_test=true;RUN_CODEGEN_TESTS=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;BUILD_GFX12=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true
0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false
0 13 * * * % BUILD_LEGACY_OS=true''' : ""
......@@ -765,8 +697,8 @@ pipeline {
description: 'If you want to use a custom docker image, please specify it here (default: leave blank).')
string(
name: 'ROCMVERSION',
defaultValue: '6.2',
description: 'Specify which ROCM version to use: 6.2 (default).')
defaultValue: '6.3',
description: 'Specify which ROCM version to use: 6.3 (default).')
string(
name: 'COMPILER_VERSION',
defaultValue: '',
......@@ -829,8 +761,8 @@ pipeline {
description: "Test building instances for various architectures simultaneously (default: OFF)")
booleanParam(
name: "BUILD_GFX12",
defaultValue: false,
description: "Build CK and run tests on gfx12 (default: OFF)")
defaultValue: true,
description: "Build CK and run tests on gfx12 (default: ON)")
booleanParam(
name: "NINJA_BUILD_TRACE",
defaultValue: false,
......@@ -1240,29 +1172,6 @@ pipeline {
}
}
}
stage("Performance Tests")
{
parallel
{
stage("Run ckProfiler: gfx90a")
{
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
}
options { retry(1) }
agent{ label rocmnode("gfx90a")}
environment{
setup_args = "NO_CK_BUILD"
}
steps{
runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release')
cleanWs()
}
}
}
}
stage("Process Performance Test Results")
{
parallel
......
......@@ -4,6 +4,7 @@
#include <hip/hip_runtime_api.h>
#include <memory>
#include <string>
#include <stdexcept>
namespace rtc {
......
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
sphinxcontrib-bibtex==2.6.3
......@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.10.0
rocm-docs-core==1.11.0
# via -r requirements.in
six==1.16.0
# via pybtex
......
......@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
......
......@@ -172,12 +172,13 @@ bool run_grouped_conv_fwd(bool do_verification,
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
// values generated: -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 6});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{-5.0, 5.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-1.0, 1.0});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
......
......@@ -2,10 +2,17 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
DTYPE_MAP = {
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
MASK_IMPL = {
......
......@@ -283,7 +283,7 @@ class FmhaBwdApiPool:
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic])
......@@ -360,7 +360,7 @@ class FmhaBwdDQDKDVKernel:
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
......@@ -469,7 +469,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaBwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
......@@ -585,7 +585,7 @@ class FmhaBwdOGradDotOKernel:
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_spad = BOOL_MAP[self.F_spad],
F_dvpad = BOOL_MAP[self.F_dvpad],
F_mode = MODE_MAP[self.F_mode],
......@@ -616,7 +616,7 @@ def get_bwd_dot_do_o_blobs() -> List[FmhaBwdOGradDotOKernel]:
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
......@@ -716,7 +716,7 @@ class FmhaBwdConvertQGradKernel:
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_spad = BOOL_MAP[self.F_spad],
......@@ -751,7 +751,7 @@ def get_bwd_convert_dq_blobs() -> List[FmhaBwdConvertQGradKernel]:
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in BWD_DTYPE_MAP.keys():
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
if d == None:
continue
......
......@@ -282,7 +282,7 @@ class FmhaFwdApiPool:
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
......@@ -301,7 +301,7 @@ class FmhaFwdTileSize:
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
......@@ -339,7 +339,7 @@ class FmhaFwdKernel:
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
......@@ -462,6 +462,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
......@@ -469,7 +472,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
......
......@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
......@@ -216,7 +216,7 @@ class FmhaFwdAppendKVKernel:
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
......@@ -301,6 +301,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
......@@ -308,7 +311,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
continue
......
......@@ -112,7 +112,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
}}
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
{F_dvpad}>;
#include <iostream>
......@@ -161,7 +161,7 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
{F_hdim},
{F_bm0},
{F_bm0},
{F_bn1},
{F_mode},
fmha_trait>;
......@@ -231,11 +231,11 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
if(s.log_level_ > 0)
std::cout
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
<< std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
);
}}
......@@ -431,11 +431,11 @@ class FmhaFwdSplitKVApiPool:
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
......@@ -472,7 +472,7 @@ class FmhaFwdSplitKVKernel:
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
......@@ -492,7 +492,7 @@ class FmhaFwdSplitKVKernel:
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
......@@ -552,7 +552,7 @@ class FmhaFwdSplitKVCombineKernel:
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = DTYPE_MAP[self.F_dtype],
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn1 = self.F_tile.F_bn1,
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
......@@ -626,7 +626,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask))
......@@ -645,6 +645,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
elif dtype in ['fp8', 'bf8']:
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask))
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
......@@ -652,7 +655,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
gen = list()
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
prefill_tiles = codegen.ops.fmha_fwd.get_fmha_fwd_tile_dict_from_dtype(dtype)
decode_tiles = get_fmha_fwd_tile_dict_from_dtype(dtype)
if decode_tiles == None:
......@@ -722,7 +725,7 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
gen = list()
for dtype in DTYPE_MAP.keys():
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
if d == None:
continue
......
......@@ -101,7 +101,7 @@ auto create_args(int argc, char* argv[])
}
// different threshold for different dtype
template <typename DataType>
template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-2;
......@@ -110,7 +110,7 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
}
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
......@@ -122,7 +122,7 @@ auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType>
template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
......@@ -209,7 +209,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
using TypeConfig = FmhaBwdTypeConfig<DataType>;
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
......@@ -933,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
......@@ -986,11 +986,11 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
}
return -3;
......
......@@ -14,11 +14,19 @@
#include <utility>
#include <variant>
struct FmhaBwdFp16
{
};
struct FmhaBwdBf16
{
};
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<ck_tile::half_t>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
......@@ -38,7 +46,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
struct FmhaBwdTypeConfig<FmhaBwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
......
......@@ -3,6 +3,7 @@
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ref/naive_attention.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include "utils.hpp"
......@@ -42,7 +43,7 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
arg_parser.insert("v", "1", "0:no validation, 2:cpu validation, 2:gpu validation(experimental)")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
......@@ -143,7 +144,7 @@ auto create_args(int argc, char* argv[])
}
// different threshold for different dtype
template <typename DataType>
template <typename DataTypeConfig>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
......@@ -152,7 +153,7 @@ auto get_elimit(std::string /*init_method*/)
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
......@@ -160,7 +161,7 @@ auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
auto get_elimit<FmhaFwdFp8>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
......@@ -240,7 +241,7 @@ int override_num_splits_if_necessary(int batch,
return num_splits;
}
template <typename DataType>
template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
......@@ -284,8 +285,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
if constexpr(!(std::is_same_v<DataTypeConfig, FmhaFwdFp16> ||
std::is_same_v<DataTypeConfig, FmhaFwdBf16>))
{
if(0 < rotary_dim)
{
......@@ -407,25 +408,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
return atoi(squant_str.c_str()) != 0 ? true : false;
}();
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<DataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max);
scale_p = dtype_max / range_p;
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
scale_o = range_p * range_v / range_o / dtype_max;
}
std::string vlayout = arg_parser.get_str("vlayout");
bool lse = arg_parser.get_bool("lse");
......@@ -445,7 +427,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
bool s_randval = false;
if(p_drop > 0.0f && do_validation)
if(p_drop > 0.0f && do_validation != 0)
{
s_randval = true;
}
......@@ -478,7 +460,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
using TypeConfig = FmhaFwdTypeConfig<DataType>;
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
......@@ -492,6 +474,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
using OaccDataType = typename TypeConfig::OaccDataType;
using ODataType = typename TypeConfig::ODataType;
float range_q = arg_parser.get_float("range_q");
float range_k = arg_parser.get_float("range_k");
float range_v = arg_parser.get_float("range_v");
float range_p = arg_parser.get_float("range_p");
float range_o = arg_parser.get_float("range_o");
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
float p_dtype_max = v_dtype_max; // assume p and v is the same type
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
float scale_p = 1.f;
float scale_o = 1.f;
if(squant)
{
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
scale_p = p_dtype_max / range_p;
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
}
// accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
......@@ -695,14 +699,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(init_method == "ufq" || init_method == "uf:q" ||
init_method == "3") // suitable for fp8 quantization
{
ck_tile::FillUniformDistribution<QDataType>{-dtype_max, dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-dtype_max, dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-dtype_max, dtype_max, seed}(vnew_host);
ck_tile::FillUniformDistribution<QDataType>{-q_dtype_max, q_dtype_max, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(k_host);
ck_tile::FillUniformDistribution<KDataType>{-k_dtype_max, k_dtype_max, seed}(knew_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(v_host);
ck_tile::FillUniformDistribution<VDataType>{-v_dtype_max, v_dtype_max, seed}(vnew_host);
// bias_fp8 = qscale_bias * bias_fp32
float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k);
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
// Assume bias is in [-1.f, 1.f] in original fp32
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
}
......@@ -1104,25 +1108,75 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush;
if(!do_validation)
if(do_validation == 0)
{
std::cout << std::flush << std::endl;
return true;
}
if(do_validation == 2)
{
// NOTE: use gpu to do validation
ck_tile::naive_attention_fwd_traits naive_t;
naive_t.q_type = data_type;
naive_t.k_type = data_type;
naive_t.v_type = data_type;
naive_t.o_type = data_type;
naive_t.q_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.k_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.v_layout = i_perm == 1 ? "bhsd" : "bshd";
naive_t.o_layout = o_perm == 1 ? "bhsd" : "bshd";
naive_t.variation = 0; // TODO?
ck_tile::DeviceMem o_naive_buf(o_host.get_element_space_size_in_bytes());
ck_tile::naive_attention_fwd_args naive_a;
naive_a.q_ptr = q_buf.GetDeviceBuffer();
naive_a.k_ptr = k_buf.GetDeviceBuffer();
naive_a.v_ptr = v_buf.GetDeviceBuffer();
naive_a.o_ptr = o_naive_buf.GetDeviceBuffer();
naive_a.scale_s = scale_s;
naive_a.context_len_ptr = nullptr; // used when seqlen kv come from a pointer
naive_a.page_table_ptr =
nullptr; // [batch, num_blocks] seqlen_kv is in different block(paged attn)
naive_a.hdim = hdim_q;
naive_a.hdim_v = hdim_v; // could be cross-attn, where V and Q/K hdim are different
naive_a.batch_q = batch;
naive_a.batch_kv = batch;
naive_a.batch_ratio_kv = 1; // batch_q / batch_kv
naive_a.seqlen_q = seqlen_qs[0];
naive_a.seqlen_kv = seqlen_ks[0]; // if context_len_ptr is not nullptr, ignore this field
naive_a.nhead_q = nhead;
naive_a.nhead_kv = nhead_k;
naive_a.nhead_ratio_kv = naive_a.nhead_q / naive_a.nhead_kv; // nhead_q / nhead_kv
naive_a.page_size = 0; // if paged, the seqlen-kv for each block
ck_tile::stream_config naive_s{};
naive_attention_fwd(naive_t, naive_a, naive_s);
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
std::cout << ", valid:" << (pass_ ? "y" : "n") << std::flush << std::endl;
return pass_;
}
o_buf.FromDevice(o_host.data());
lse_buf.FromDevice(lse_host.data());
randval_buf.FromDevice(randval_host.data());
auto p_compute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::scales{scale_p};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
if constexpr(std::is_same_v<DataTypeConfig, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
else
......@@ -1172,7 +1226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
decltype(q_host_ref) q_host_ref_ro(q_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], real_seqlen_q);
ck_tile::reference_batched_rotary_position_embedding(
......@@ -1188,13 +1242,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[0] / nr, i[1] % page_block_size, i[2]);
});
} else {
} else {
k_host_ref.ForEach([&](auto& self, auto i) {
self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]);
});
}
} else
#endif
#endif
{
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
......@@ -1215,7 +1269,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
knew_host_ref_ro.emplace(knew_host_ref.get_lengths());
auto [rotary_cos_slice, rotary_sin_slice] =
auto [rotary_cos_slice, rotary_sin_slice] =
slice_rotary_cos_sin(rotary_cos_host, rotary_sin_host, cache_seqlen_ks[wb], seqlen_knew);
ck_tile::reference_batched_rotary_position_embedding(
......@@ -1237,19 +1291,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < page_block_size) {
if(is_v_rowmajor) {
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[2] % page_block_size, i[1]);
});
} else {
v_host_ref.ForEach([&](auto& self, auto i) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[2] % page_block_size, i[0] / nr, i[1]);
});
}
}
else
else
{
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
if(i_perm) {
v_host_ref.ForEach([&](auto& self, auto i) {
self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size);
});
} else {
......@@ -1444,7 +1498,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass;
......@@ -1501,15 +1555,15 @@ int main(int argc, char* argv[])
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdFp16>(arg_parser) ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdBf16>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp8")
{
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
return run<FmhaFwdFp8>(arg_parser) ? 0 : -2;
}
return -3;
......
......@@ -16,11 +16,35 @@
#include <utility>
#include <variant>
struct FmhaFwdFp16
{
};
struct FmhaFwdBf16
{
};
struct FmhaFwdFp8
{
};
struct FmhaFwdBf8
{
};
struct FmhaFwdFp8Fp16
{
};
struct FmhaFwdFp8Bf16
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<ck_tile::half_t>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
......@@ -36,7 +60,7 @@ struct FmhaFwdTypeConfig<ck_tile::half_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf16_t>
struct FmhaFwdTypeConfig<FmhaFwdBf16>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
......@@ -52,7 +76,7 @@ struct FmhaFwdTypeConfig<ck_tile::bf16_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::fp8_t>
struct FmhaFwdTypeConfig<FmhaFwdFp8>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
......@@ -68,7 +92,7 @@ struct FmhaFwdTypeConfig<ck_tile::fp8_t>
};
template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t>
struct FmhaFwdTypeConfig<FmhaFwdBf8>
{
using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t;
......
......@@ -92,6 +92,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......
......@@ -119,6 +119,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
......
......@@ -35,7 +35,8 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
......@@ -49,11 +50,14 @@ auto create_args(int argc, char* argv[])
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
......@@ -68,14 +72,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
......@@ -116,7 +120,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.GetDeviceBuffer(),
m,
n,
stride};
x_stride,
y_stride};
auto kargs = Kernel::MakeKargs(args);
......@@ -133,7 +138,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
......@@ -183,7 +188,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
......@@ -195,10 +200,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
......@@ -210,8 +217,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
}
return pass;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment