Unverified commit 3aec6f03 authored by arai713, committed by GitHub

Merge branch 'develop' into codegen_hiprtc

parents cdfceb0a 5d671a5f
@@ -326,13 +326,39 @@ def cmake_build(Map conf=[:]){
     if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) {
         archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
     }
+    //check the node gpu architecture
+    def arch_type = 0
+    sh 'rocminfo | tee rocminfo.log'
+    if ( runShell('grep -n "gfx90a" rocminfo.log') ){
+        arch_type = 1
+    }
+    else if ( runShell('grep -n "gfx942" rocminfo.log') ) {
+        arch_type = 2
+    }
     if (params.RUN_CK_TILE_FMHA_TESTS){
         try{
-            archiveArtifacts "perf_fmha_fwd_*.log"
-            archiveArtifacts "perf_fmha_bwd_*.log"
-            stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942"
-            stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a"
+            archiveArtifacts "perf_fmha_*.log"
+            if (arch_type == 1){
+                stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a"
+            }
+            else if (arch_type == 2){
+                stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942"
+            }
+        }
+        catch(Exception err){
+            echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
+        }
+    }
+    if (params.RUN_CK_TILE_GEMM_TESTS){
+        try{
+            archiveArtifacts "perf_tile_gemm_*.log"
+            if (arch_type == 1){
+                stash includes: "perf_tile_gemm_**_fp16_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
+            }
+            else if (arch_type == 2){
+                stash includes: "perf_tile_gemm_**_fp16_gfx942.log", name: "perf_tile_gemm_log_gfx942"
+            }
         }
         catch(Exception err){
             echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
         }
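`runShell` in the arch probe above is a Jenkinsfile helper, presumably returning whether the grep matched. The same probe as plain shell, for reference — a sketch, not part of the commit (`grep -q` stands in for the helper's `grep -n` plus truthiness check):

```
# Probe the node's GPU architecture the same way the pipeline does:
rocminfo | tee rocminfo.log
arch_type=0
if grep -q "gfx90a" rocminfo.log; then
    arch_type=1   # MI200-class node
elif grep -q "gfx942" rocminfo.log; then
    arch_type=2   # MI300-class node
fi
echo "arch_type=$arch_type"
```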
@@ -630,6 +656,15 @@ def process_results(Map conf=[:]){
             echo "could not locate the FMHA performance logs: ${err.getMessage()}."
         }
     }
+    if (params.RUN_CK_TILE_GEMM_TESTS){
+        try{
+            unstash "perf_tile_gemm_log_gfx942"
+            unstash "perf_tile_gemm_log_gfx90a"
+        }
+        catch(Exception err){
+            echo "could not locate the GEMM performance logs: ${err.getMessage()}."
+        }
+    }
     if (params.RUN_FULL_QA){
         // unstash perf files to master
         unstash "ckprofiler_0.2.0_amd64.deb"
@@ -956,7 +991,7 @@ pipeline {
         environment{
             setup_args = "NO_CK_BUILD"
            execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
-                            make -j64 tile_example_gemm_basic && \
+                            make -j64 tile_example_gemm_basic tile_example_gemm_universal && \
                             cd ../ &&
                             example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
         }
@@ -975,7 +1010,7 @@ pipeline {
         environment{
             setup_args = "NO_CK_BUILD"
            execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
-                            make -j64 tile_example_gemm_basic && \
+                            make -j64 tile_example_gemm_basic tile_example_gemm_universal && \
                             cd ../ &&
                             example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
         }
......
 add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
-add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
+add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
@@ -11,9 +11,9 @@ sh ../script/cmake-ck-dev.sh ../ <arch>
 # The basic pipeline method on the gemm calculation
 make tile_example_gemm_basic -j
 # The memory bound pipeline on the gemm calculation
-make tile_example_gemm_mem_pipeline -j
+make tile_example_gemm_universal -j
 ```
-This will result in an executable `build/bin/tile_example_gemm_basic`
+This will result in the executables `build/bin/tile_example_gemm_basic` and `build/bin/tile_example_gemm_universal`
 ## example
 ```
@@ -22,6 +22,9 @@ args:
 -m         m dimension (default:1024)
 -n         n dimension (default:2048)
 -k         k dimension (default:64)
+-a_layout  Tensor A data layout (default: R)
+-b_layout  Tensor B data layout (default: R)
+-c_layout  Tensor C data layout (default: R)
 -stride_a  Tensor A stride (default:0)
 -stride_b  Tensor B stride (default:0)
 -stride_c  Tensor C stride (default:0)
......
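For reference, a typical invocation of the new binary, assembled from the arguments listed above plus the `-prec`, `-b`, and `-v` flags used by the smoke-test scripts later in this diff; the concrete sizes are hypothetical:

```
./build/bin/tile_example_gemm_universal -prec=fp16 -b=1 -m=1024 -n=2048 -k=64 \
    -a_layout=R -b_layout=C -c_layout=R -v=1
```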
@@ -9,8 +9,6 @@
 #include <string>
 #include <tuple>
 
-#include "ck_tile/ops/epilogue.hpp"
-#include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "gemm_basic.hpp"
......
@@ -8,6 +8,27 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/gemm.hpp"
+
+#define CK_TILE_PIPELINE_COMPUTE 1
+#define CK_TILE_PIPELINE_MEMORY 2
+
+#ifndef CK_TILE_PIPELINE_DEFAULT
+#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
+#endif
+
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
+#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
+#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
+#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
+#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
+#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
+#else
+#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
+#endif
+
 template <typename DataType>
 struct GemmBasicTypeConfig;
......
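The header defaults to the compute-optimized pipeline. A minimal sketch of selecting the other variant, assuming the macro is defined before the header is included (for example on the compiler command line):

```
// A sketch, not from the commit: build the memory-bound variant by defining
// CK_TILE_PIPELINE_DEFAULT to CK_TILE_PIPELINE_MEMORY (i.e. 2) up front:
//   hipcc -DCK_TILE_PIPELINE_DEFAULT=2 ... universal_gemm.cpp
// With that define in effect, the macros above resolve to:
//   GEMM_PIPELINE           -> ck_tile::GemmPipelineAgBgCrMem
//   UNIVERSAL_GEMM_PIPELINE -> ck_tile::BaseGemmPipelineAgBgCrMem
//   GEMM_PIPELINE_SCHEDULER -> ck_tile::GemmPipelineScheduler::Interwave
```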
#!/bin/sh
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
VALID=0
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
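For scale, the sweep above expands as follows (the `-v` semantics are an assumption based on the other ck_tile example scripts):

```
# The loops cover 2 (B layouts) x 4 (m) x 3 (n) x 4 (k) = 96 problem sizes
# per run; VALID=0 presumably disables host-side verification so the smoke
# sweep stays fast.
```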
#!/bin/sh
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
VALID=0
for b_matrix_layout in "R" "C"; do
for m in "64" "512" "1024" "2048"; do
for n in "512" "1024" "2048"; do
for k in "64" "512" "1024" "2048"; do
$EXE -prec=fp16 -b=1 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
done
done
done
done
@@ -19,7 +19,27 @@ echo 'Host name: ' $host_name
 export GPU_arch=$4
 echo 'GPU_arch: ' $GPU_arch
 
+function print_log_header(){
+    rm -f $1;
+    echo 'On branch ' $3 &> $1;
+    echo 'Node name: ' $4 >> $1;
+    # get GPU architecture and compute units from rocminfo
+    echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
+    rocminfo | grep "Compute Unit:" >> $1;
+    hipcc --version | grep -e 'HIP version' >> $1;
+    echo 'Environment type: ' $2 >> $1;
+    /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
+}
+
 # run verification tests
-example/ck_tile/03_gemm/script/smoke_test.sh
-# We do not have a performance benchmark for gemm yet. Will add it in the future.
\ No newline at end of file
+example/ck_tile/03_gemm/script/smoke_test_basic.sh
+example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
+
+# run performance benchmarks
+export gemm_basic_log="perf_tile_gemm_basic_fp16_$GPU_arch.log"
+print_log_header $gemm_basic_log $env_type $branch $host_name
+example/ck_tile/03_gemm/script/benchmark_basic.sh 2>&1 | tee -a $gemm_basic_log
+
+export gemm_mem_pipeline_log="perf_tile_gemm_mem_pipeline_fp16_$GPU_arch.log"
+print_log_header $gemm_mem_pipeline_log $env_type $branch $host_name
+example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 | tee -a $gemm_mem_pipeline_log
#!/bin/bash
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
run_fp16_tests() {
for batch in 1 2; do
for m in 128 1024; do
for n in 128 2048; do
for k in 32 64; do
$EXE -b=$batch -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -e=1e-5 -prec=fp16 $COMMON_ARGS
if [ $? -eq 0 ]; then
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
else
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
# Optionally, exit or break if you need to halt further execution
# exit 1
fi
done
done
done
done
}
set -x
run_fp16_tests
set +x
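A sketch of how this benchmark is driven; the script name is inferred from the `run_full_test.sh` hunk above, which pipes it into the per-architecture log:

```
# From the repository root, after building tile_example_gemm_universal:
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh 2>&1 \
    | tee -a perf_tile_gemm_mem_pipeline_fp16_gfx90a.log
```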
@@ -9,18 +9,9 @@
 #include <string>
 #include <tuple>
 
-#include "ck_tile/ops/epilogue.hpp"
-#include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
 #include "gemm_basic.hpp"
 
-#define CK_TILE_PIPELINE_COMPUTE 1
-#define CK_TILE_PIPELINE_MEMORY 2
-
-#ifndef CK_TILE_PIPELINE_DEFAULT
-#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
-#endif
-
 template <typename ALayout, typename BLayout, typename CLayout>
 float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
 {
@@ -71,12 +62,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
 
     using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
-    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
-    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
-#endif
-        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
+
+    using GemmPipelineProblem =
+        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
+
+    using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
 
     const ck_tile::index_t k_grain = args.k_batch * K_Tile;
     const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
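The `K_split` line rounds K up to a whole number of K tiles per split; a worked example with hypothetical sizes:

```
// Hypothetical sizes, not from the commit:
//   K = 1000, k_batch = 2, K_Tile = 32
//   k_grain = k_batch * K_Tile          = 64
//   K_split = (1000 + 64 - 1) / 64 * 32 = 16 * 32 = 512
// i.e. each of the two K-batches covers 512 elements, a multiple of K_Tile,
// and together they cover all 1000.
```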
@@ -89,24 +79,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;
+        constexpr auto scheduler      = GEMM_PIPELINE_SCHEDULER;
 
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
-        using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
-        using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
-#endif
-            ck_tile::UniversalGemmPipelineProblem<ADataType,
-                                                  BDataType,
-                                                  AccDataType,
-                                                  GemmShape,
-                                                  Traits,
-#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
-                                                  ck_tile::GemmPipelineScheduler::Interwave,
-#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
-                                                  ck_tile::GemmPipelineScheduler::Intrawave,
-#endif
-                                                  has_hot_loop_v,
-                                                  tail_number_v>>;
+        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
+                                                                           BDataType,
+                                                                           AccDataType,
+                                                                           GemmShape,
+                                                                           Traits,
+                                                                           scheduler,
+                                                                           has_hot_loop_v,
+                                                                           tail_number_v>;
+
+        using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
         using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
         auto kargs   = Kernel::MakeKernelArgs(args);
......
@@ -52,7 +52,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // using WarpTile = ck_tile::sequence<1, 512>;
     // using Vector   = ck_tile::sequence<1, 8>;
 
-    constexpr ck_tile::index_t kBlockSize  = 512;
+    constexpr ck_tile::index_t kBlockSize  = 256;
     constexpr ck_tile::index_t kBlockPerCu = 1;
     ck_tile::index_t kGridSize             = (m / BlockTile::at(ck_tile::number<0>{}));
     std::cout << "grid size " << kGridSize << std::endl;
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -1558,14 +1558,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             }
         }
 
-        if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
+        const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 &&
+                                   arg.input_right_pads_[NDimSpatial - 1] == 0;
+        const auto X             = arg.filter_spatial_lengths_[NDimSpatial - 1];
+        const bool XC_access_allowed =
+            arg.Conv_G_ == 1 &&
+            (arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 && is_w_pad_zero;
+
+        if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) &&
             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
         {
-            if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
+            if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 &&
+                 NumGroupsToMerge > 1))
            {
                return false;
            }
-            if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
+            if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 &&
+                 NumGroupsToMerge > 1))
            {
                return false;
            }
......
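The relaxed check admits cases where C alone is not divisible by the vector size but the merged X·C dimension is; a worked example with hypothetical shapes:

```
// Hypothetical shapes, not from the commit:
//   Conv_G_ = 1, Conv_C_ = 3, X = 7, BBlockTransferSrcScalarPerVector = 7,
//   left/right W padding = 0
// Old check:  3 % 7 == 0                              -> false, case rejected
// New check:  G == 1 && (3 * 7) % 7 == 0 && no W pad  -> XC_access_allowed,
// so vector loads may legally run across the contiguous merged X*C dimension.
```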
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -584,6 +584,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         {
             return false;
         }
+        if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
+        {
+            return false;
+        }
 
         if constexpr(NDimSpatial == 1)
         {
             if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
......
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -54,6 +54,19 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
     static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};
+    static constexpr auto I8 = Number<8>{};
+    static constexpr auto I10 = Number<10>{};
+    static constexpr auto I12 = Number<12>{};
+    static constexpr auto I13 = Number<13>{};
+    static constexpr auto I14 = Number<14>{};
+    static constexpr auto I16 = Number<16>{};
 
     static constexpr index_t PackedSize = []() {
         if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
@@ -198,9 +211,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 src_oob_thread_scratch_tuple_(thread_scratch_id)
                     .template SetAsType<bool>(src_data_idx_seq, is_src_valid);
 
-                using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
-                using src_vector_t    = typename src_vector_type::type;
-
                 using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
                 using dst_vector_t    = typename dst_vector_type::type;
                 dst_vector_type op_r_v;
@@ -234,13 +244,62 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                 using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
                 using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
 
-                auto src_vector_container = src_vector_type{
-                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
-
-                static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) {
-                    // apply the src elementwise op and convert to DstData under the hood if needed
-                    src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx),
-                                    src_vector_container.template AsType<src_elem_op_vec_t>()[idx]);
+                using VectorSizeLookupTable = Tuple<Sequence<>,
+                                                    Sequence<I1>,
+                                                    Sequence<I2>,
+                                                    Sequence<I2, I1>,
+                                                    Sequence<I4>,
+                                                    Sequence<I4, I1>,
+                                                    Sequence<I4, I2>,
+                                                    Sequence<I4, I2, I1>,
+                                                    Sequence<I8>,
+                                                    Sequence<I8, I1>,
+                                                    Sequence<I8, I2>,
+                                                    Sequence<I8, I2, I1>,
+                                                    Sequence<I8, I4>,
+                                                    Sequence<I8, I4, I1>,
+                                                    Sequence<I8, I4, I2>,
+                                                    Sequence<I8, I4, I2, I1>,
+                                                    Sequence<I16>>;
+
+                using VectorOffsetsLookupTable = Tuple<Sequence<>,
+                                                       Sequence<I0>,
+                                                       Sequence<I0>,
+                                                       Sequence<I0, I2>,
+                                                       Sequence<I0>,
+                                                       Sequence<I0, I4>,
+                                                       Sequence<I0, I4>,
+                                                       Sequence<I0, I4, I6>,
+                                                       Sequence<I0>,
+                                                       Sequence<I0, I8>,
+                                                       Sequence<I0, I8>,
+                                                       Sequence<I0, I8, I10>,
+                                                       Sequence<I0, I8>,
+                                                       Sequence<I0, I8, I12>,
+                                                       Sequence<I0, I8, I12>,
+                                                       Sequence<I0, I8, I12, I14>,
+                                                       Sequence<I0>>;
+
+                static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
+                    [&](auto v_idx) {
+                        constexpr auto VectorLoadSize =
+                            tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
+                        constexpr auto LoadOffset =
+                            tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
+
+                        using src_vector_container   = vector_type_maker_t<SrcData, VectorLoadSize>;
+                        using src_vector_container_t = typename src_vector_container::type;
+
+                        src_vector_container src_vector =
+                            src_vector_container{src_buf.template Get<src_vector_container_t>(
+                                src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
+
+                        static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
+                            // apply the src elementwise op and convert to DstData under the hood if
+                            // needed
+                            src_element_op_(
+                                op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
+                                src_vector.template AsType<src_elem_op_vec_t>()[idx]);
+                        });
                 });
 
                 // copy data from src_vector_container into src_thread_scratch_
......
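The two lookup tables encode, for each `SrcScalarPerVector`, a greedy power-of-two decomposition of the load and the offsets at which the chunks start. A self-contained sketch of the same decomposition (standalone C++, not CK code):

```
#include <cstdio>

int main()
{
    // e.g. SrcScalarPerVector == 13 -> Sequence<I8, I4, I1> / Sequence<I0, I8, I12>
    const int scalar_per_vector = 13;

    int offset = 0;
    for(int chunk = 16; chunk >= 1; chunk /= 2) // largest power-of-two chunks first
    {
        if(scalar_per_vector & chunk)
        {
            std::printf("load %2d elements at offset %2d\n", chunk, offset);
            offset += chunk;
        }
    }
    // prints: 8 @ 0, 4 @ 8, 1 @ 12 -- matching the two tables above
}
```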
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -327,7 +327,77 @@ struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d3_t __attribute__((ext_vector_type(3)));
+
+    using type = d3_t;
+
+    union
+    {
+        d3_t d3_;
+        StaticallyIndexedArray<d1_t, 3> d1x3_;
+        StaticallyIndexedArray<d2_t, 1> d2x1_;
+        StaticallyIndexedArray<d3_t, 1> d3x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x3_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x1_;
+        }
+        else if constexpr(is_same<X, d3_t>::value)
+        {
+            return data_.d3x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x3_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x1_;
+        }
+        else if constexpr(is_same<X, d3_t>::value)
+        {
+            return data_.d3x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
 {
     using d1_t = T;
     typedef T d2_t __attribute__((ext_vector_type(2)));
@@ -397,7 +467,159 @@ struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
+{
+    using d1_t = T;
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d5_t __attribute__((ext_vector_type(5)));
+
+    using type = d5_t;
+
+    union
+    {
+        d5_t d5_;
+        StaticallyIndexedArray<d1_t, 5> d1x5_;
+        StaticallyIndexedArray<d4_t, 1> d4x1_;
+        StaticallyIndexedArray<d5_t, 1> d5x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x5_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x1_;
+        }
+        else if constexpr(is_same<X, d5_t>::value)
+        {
+            return data_.d5x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x5_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x1_;
+        }
+        else if constexpr(is_same<X, d5_t>::value)
+        {
+            return data_.d5x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d7_t __attribute__((ext_vector_type(7)));
+
+    using type = d7_t;
+
+    union
+    {
+        d7_t d7_;
+        StaticallyIndexedArray<d1_t, 7> d1x7_;
+        StaticallyIndexedArray<d2_t, 3> d2x3_;
+        StaticallyIndexedArray<d4_t, 1> d4x1_;
+        StaticallyIndexedArray<d7_t, 1> d7x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d7_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x7_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x3_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x1_;
+        }
+        else if constexpr(is_same<X, d7_t>::value)
+        {
+            return data_.d7x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d7_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x7_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x3_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x1_;
+        }
+        else if constexpr(is_same<X, d7_t>::value)
+        {
+            return data_.d7x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
 {
     using d1_t = T;
     typedef T d2_t __attribute__((ext_vector_type(2)));
@@ -479,7 +701,89 @@ struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
 };
 
 template <typename T>
-struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
+struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
+{
+    using d1_t = T;
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d13_t __attribute__((ext_vector_type(13)));
+
+    using type = d13_t;
+
+    union
+    {
+        d13_t d13_;
+        StaticallyIndexedArray<d1_t, 13> d1x13_;
+        StaticallyIndexedArray<d4_t, 3> d4x3_;
+        StaticallyIndexedArray<d8_t, 1> d8x1_;
+        StaticallyIndexedArray<d13_t, 1> d13x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value || is_same<X, d13_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x13_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x3_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x1_;
+        }
+        else if constexpr(is_same<X, d13_t>::value)
+        {
+            return data_.d13x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
+                          is_same<X, d8_t>::value || is_same<X, d13_t>::value,
+                      "Something went wrong, please check src and dst types.");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x13_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x3_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x1_;
+        }
+        else if constexpr(is_same<X, d13_t>::value)
+        {
+            return data_.d13x1_;
+        }
+        else
+        {
+            return err;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
 {
     using d1_t = T;
     typedef T d2_t __attribute__((ext_vector_type(2)));
......
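A hedged usage sketch of the new odd-width specializations (hypothetical caller code; the `AsType` access pattern mirrors how the existing power-of-two specializations are used elsewhere in CK):

```
// Hypothetical example, not from the commit:
ck::vector_type<float, 3> v{};   // native float3, zero-initialized
// scalar view (d1x3_): set the middle lane
v.AsType<float>()(ck::Number<1>{}) = 2.0f;
// whole-vector view (d3x1_): read the float3 back out
using d3_t = typename ck::vector_type<float, 3>::type;
d3_t xyz   = v.AsType<d3_t>()[ck::Number<0>{}];
```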
@@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS
         return Policy::template GetSmemSize<Problem>();
     }
 
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
-    {
-        return Policy::template GetSmemSizeQ<Problem>();
-    }
-
     template <typename QDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
@@ -328,7 +323,6 @@ struct BlockFmhaPipelineQSKSVS
                 });
             }
 
-            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
             { // tail
                 block_sync_lds();
                 gemm_0(s_acc, q_lds_window, k_lds_window);
@@ -341,6 +335,10 @@ struct BlockFmhaPipelineQSKSVS
                 gemm_0(s_acc, q_lds_window, k_lds_window);
             }
 
+            __builtin_amdgcn_sched_barrier(0);
+            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
+            __builtin_amdgcn_sched_barrier(0);
+
             // STAGE 2, scale_s, add bias, mask, softmax
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
@@ -462,6 +460,12 @@ struct BlockFmhaPipelineQSKSVS
                 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
             block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
 
+            const auto p =
+                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+
+            __builtin_amdgcn_sched_barrier(0);
+
             // l{j}, Oacc{j}
             constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
             sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
@@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS
             }
             move_tile_window(v_dram_window, {0, kK1});
 
-            const auto p =
-                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
-
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {
......
@@ -9,11 +9,33 @@
 namespace ck_tile {
 
 // This pipeline is qkv all located in LDS
-using BlockFmhaPipelineQSKSVSDefaultPolicy =
-    BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
-                                        /* AsyncCopyK = */ false,
-                                        /* AsyncCopyV = */ false,
-                                        /* NumPrefetchK = */ 1,
-                                        /* NumPrefetchV = */ 1>;
+struct BlockFmhaPipelineQSKSVSDefaultPolicy
+    : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
+                                          /* AsyncCopyK = */ false,
+                                          /* AsyncCopyV = */ false,
+                                          /* NumPrefetchK = */ 1,
+                                          /* NumPrefetchV = */ 1>
+{
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
+    {
+        return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::KDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
+    {
+        return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
+               sizeof(typename Problem::VDataType);
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
+               GetSmemSizeDropout<Problem>();
+    }
+};
 
 } // namespace ck_tile
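`GetSmemSize()` takes a max rather than a sum because the Q and K tiles must be resident together for the first gemm, while the V tile can reuse that region afterwards. With hypothetical per-tile sizes:

```
// Hypothetical LDS budgets, not from the commit:
//   GetSmemSizeQ() = 8 KiB, GetSmemSizeK() = 16 KiB, GetSmemSizeV() = 20 KiB,
//   GetSmemSizeDropout() = 0
// GetSmemSize() = max(8 + 16, 20) + 0 = 24 KiB
```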
@@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
     {
-        using QDataType = remove_cvref_t<typename Problem::QDataType>;
-        return 16 / sizeof(QDataType);
+        constexpr index_t kBlockSize = Problem::kBlockSize;
+        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
+        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
+
+        constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
+
+        // this should align with MakeQDramTileDistribution()
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        return min(ElemPerThread, MaxVectorSize);
     }
 
     template <typename Problem>
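The rewritten `GetAlignmentQ()` caps the vector width by how many Q elements a thread actually owns; a worked example with a hypothetical problem shape:

```
// Hypothetical shape, not from the commit: QDataType = half_t (2 bytes),
// kBlockSize = 256, kM0 = 64, kK0 = 16.
//   MaxVectorSize = 16 / 2          = 8
//   ElemPerThread = (64 * 16) / 256 = 4
//   GetAlignmentQ = min(4, 8)       = 4
// The old version returned 8 unconditionally, over-stating the achievable
// vector load width whenever a thread owns fewer than 8 Q elements.
```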
@@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
         using QDataType = remove_cvref_t<typename Problem::QDataType>;
 
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
         constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
 
-        constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
-        constexpr index_t K0 = kKPerBlock / K1;
-        constexpr index_t M2 = get_warp_size() / K0;
-        constexpr index_t M1 = kBlockSize / get_warp_size();
-        constexpr index_t M0 = kMPerBlock / (M2 * M1);
+        constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
+
+        constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
+        static_assert(0 < ElemPerThread);
+        constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
+
+        constexpr index_t KPerThread     = kMaxVecLoad;
+        constexpr index_t KThreads       = kKPerBlock / KPerThread;
+        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
+        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
+        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);
 
         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
+                                       tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
                                        tuple<sequence<1>, sequence<1, 2>>,
                                        tuple<sequence<1>, sequence<2, 0>>,
                                        sequence<1, 2>,
@@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
                                  typename Problem::BlockFmhaShape::Gemm0BlockWarps,
                                  typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
 
+        constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
+        static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
+
         constexpr auto warp_gemm = []() {
             if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
                          std::is_same_v<typename Problem::KDataType, half_t> &&
                          std::is_same_v<typename Problem::SaccDataType, float>)
             {
+                if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaF16F16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
                               std::is_same_v<typename Problem::KDataType, bf16_t> &&
                               std::is_same_v<typename Problem::SaccDataType, float>)
             {
+                if constexpr(WarpGemmM == 32)
                     return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
+                else if constexpr(WarpGemmM == 16)
+                    return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
+                else // WarpGemmM == 4
+                    return WarpGemmMfmaBf16Bf16F32M4N64K16{};
             }
             else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
                               std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
@@ -22,34 +22,19 @@ struct BlockGemmARegBRegCRegV1
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
 
     static constexpr index_t kBlockSize = Problem::kBlockSize;
+    static constexpr index_t MPerBlock  = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock  = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock  = BlockGemmShape::kK;
+    static constexpr auto config        = Policy::template GetWarpGemmMWarpNWarp<Problem>();
 
-    // C += A * B
-    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
-    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                                   const ABlockTensor& a_block_tensor,
-                                   const BBlockTensor& b_block_tensor) const
-    {
-        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
-                          std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
-                          std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
-                      "wrong!");
-
-        constexpr index_t MPerBlock = BlockGemmShape::kM;
-        constexpr index_t NPerBlock = BlockGemmShape::kN;
-        constexpr index_t KPerBlock = BlockGemmShape::kK;
-
-        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
 
     using WG = remove_cvref_t<decltype(config.template at<0>())>;
 
-        constexpr index_t MWarp = config.template at<1>();
-        constexpr index_t NWarp = config.template at<2>();
-
-        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-        constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+    static constexpr index_t MWarp = config.template at<1>();
+    static constexpr index_t NWarp = config.template at<2>();
 
-        // M->N Warp
+    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
+    static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
+    static constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
+
+    CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
+    {
         constexpr auto a_block_outer_dstr_encoding =
             tile_distribution_encoding<sequence<NWarp>,
                                        tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
@@ -57,7 +42,14 @@ struct BlockGemmARegBRegCRegV1
                                        tuple<sequence<1, 0>>,
                                        sequence<1, 2>,
                                        sequence<0, 0>>{};
+        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
+        return a_block_dstr_encode;
+    }
 
+    CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
+    {
         constexpr auto b_block_outer_dstr_encoding =
             tile_distribution_encoding<sequence<MWarp>,
                                        tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
@@ -65,7 +57,14 @@ struct BlockGemmARegBRegCRegV1
                                        tuple<sequence<0, 1>>,
                                        sequence<1, 2>,
                                        sequence<0, 0>>{};
+        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
+        return b_block_dstr_encode;
+    }
 
+    CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
+    {
         constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
             sequence<>,
             tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
@@ -73,15 +72,28 @@ struct BlockGemmARegBRegCRegV1
             tuple<sequence<1, 1>>,
             sequence<1, 2>,
             sequence<0, 0>>{};
+        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
+        return c_block_dstr_encode;
+    }
 
-        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
-        constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
+    // C += A * B
+    template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
+    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                   const ABlockTensor& a_block_tensor,
+                                   const BBlockTensor& b_block_tensor) const
+    {
+        static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
+                          std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
+                          std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
+                      "wrong!");
 
-        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
+        constexpr auto a_block_dstr_encode = MakeABlockDistributionEncode();
+        constexpr auto b_block_dstr_encode = MakeBBlockDistributionEncode();
+        constexpr auto c_block_dstr_encode = MakeCBlockDistributionEncode();
 
         // check ABC-block-distribution
         static_assert(
@@ -159,20 +171,6 @@ struct BlockGemmARegBRegCRegV1
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
-        constexpr index_t MPerBlock = BlockGemmShape::kM;
-        constexpr index_t NPerBlock = BlockGemmShape::kN;
-
-        constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
-
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
-
-        constexpr index_t MWarp = config.template at<1>();
-        constexpr index_t NWarp = config.template at<2>();
-
-        constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
-        constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
-        // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
-
         constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
             sequence<>,
             tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
......