Unverified Commit f9466a75 authored by Illia Silin, committed by GitHub

Merge branch 'develop' into rimadduri/grouped_gemm_async_memcpy

parents 89d8fca1 44828b7c
@@ -77,10 +77,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
 # Remove unnecessary rocm components that take a lot of space
 apt-get remove -y rocblas rocfft rocsparse composablekernel-dev
-# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1
-RUN if [ "$ROCMVERSION" = "6.1" ]; then \
-        sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \
-    fi
 # Update the cmake to version 3.27.5
 RUN pip install --upgrade cmake==3.27.5 && \
 #Install latest ccache
......
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub20.04_rocm6.2"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""
# Add alternative compilers, if necessary
ENV compiler_version=$compiler_version
ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 16 ; \
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 16 ; \
else echo "using the release compiler"; \
fi
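For reference, a minimal sketch of how this add-on Dockerfile might be invoked by hand; the tag and argument values below are illustrative only, and in CI the equivalent arguments are assembled by the Jenkins pipeline shown next.

```
# Sketch only: build a compiler image on top of an existing CK base image.
# BASE_DOCKER, compiler_version and compiler_commit are the ARGs declared above;
# the output tag is a made-up example.
docker build -f Dockerfile.compiler \
    --build-arg BASE_DOCKER="rocm/composable_kernel:ck_ub20.04_rocm6.2" \
    --build-arg compiler_version="amd-staging" \
    --build-arg compiler_commit="" \
    -t composable_kernel:ck_ub20.04_rocm6.2_amd-staging .
```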
@@ -32,41 +32,42 @@ def runShell(String command){
     return (output != "")
 }

-def getDockerImageName(){
+def getBaseDockerImageName(){
     def img
     if (params.USE_CUSTOM_DOCKER != ""){
         img = "${params.USE_CUSTOM_DOCKER}"
     }
     else{
         if (params.ROCMVERSION != "6.3"){
-            if (params.COMPILER_VERSION == "") {
-                img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
-            }
-            else{
-                if (params.COMPILER_COMMIT == ""){
-                    img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
-                }
-                else{
-                    def commit = "${params.COMPILER_COMMIT}"[0..6]
-                    img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
-                }
-            }
+            img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
         }
         else{
-            if (params.COMPILER_VERSION == "") {
-                img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
-            }
-            else{
-                if (params.COMPILER_COMMIT == ""){
-                    img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
-                }
-                else{
-                    def commit = "${params.COMPILER_COMMIT}"[0..6]
-                    img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
-                }
-            }
+            img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub20.04_rocm${params.ROCMVERSION}"
         }
     }
     return img
 }
+
+def getDockerImageName(){
+    def img
+    def base_name = getBaseDockerImageName()
+    if (params.USE_CUSTOM_DOCKER != ""){
+        img = "${params.USE_CUSTOM_DOCKER}"
+    }
+    else{
+        if (params.COMPILER_VERSION == "") {
+            img = "${base_name}"
+        }
+        else{
+            if (params.COMPILER_COMMIT == ""){
+                img = "${base_name}_${params.COMPILER_VERSION}"
+            }
+            else{
+                def commit = "${params.COMPILER_COMMIT}"[0..6]
+                img = "${base_name}_${params.COMPILER_VERSION}_${commit}"
+            }
+        }
+    }
+    return img
+}
@@ -131,17 +132,21 @@ def buildDocker(install_prefix){
     env.DOCKER_BUILDKIT=1
     checkout scm
     def image_name = getDockerImageName()
+    def base_image_name = getBaseDockerImageName()
     echo "Building Docker for ${image_name}"
-    def dockerArgs = "--squash --build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' --build-arg DISABLE_CACHE='git rev-parse ${params.COMPILER_VERSION}' "
+    def dockerArgs = "--build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
     if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
-        dockerArgs = dockerArgs + " --no-cache "
+        dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . "
+    }
+    else{
+        dockerArgs = dockerArgs + " -f Dockerfile . "
     }
     echo "Build Args: ${dockerArgs}"
     try{
         if(params.BUILD_DOCKER){
             //force building the new docker if that parameter is true
             echo "Building image: ${image_name}"
-            retimage = docker.build("${image_name}", dockerArgs + ' .')
+            retimage = docker.build("${image_name}", dockerArgs)
             withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) {
                 retimage.push()
             }
......
-rocm-docs-core==1.9.0
+rocm-docs-core==1.9.2
 sphinxcontrib-bibtex==2.6.3
@@ -103,7 +103,7 @@ requests==2.32.3
     # via
     #   pygithub
     #   sphinx
-rocm-docs-core==1.9.0
+rocm-docs-core==1.9.2
     # via -r requirements.in
 six==1.16.0
     # via pybtex
......
 add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
-add_executable(tile_example_gemm_mem_pipeline EXCLUDE_FROM_ALL gemm_mem_pipeline.cpp)
+add_executable(tile_example_universal_gemm EXCLUDE_FROM_ALL universal_gemm.cpp)
@@ -14,10 +14,17 @@
 #include "ck_tile/host.hpp"
 #include "gemm_basic.hpp"

+#define CK_TILE_PIPELINE_COMPUTE 1
+#define CK_TILE_PIPELINE_MEMORY 2
+
+#ifndef CK_TILE_PIPELINE_DEFAULT
+#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
+#endif
+
 template <typename ALayout, typename BLayout, typename CLayout>
 float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
 {
-#if 1
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
     // Memory friendly for Interwave scheduler
     constexpr ck_tile::index_t M_Tile = 128;
     constexpr ck_tile::index_t N_Tile = 32;
@@ -30,7 +37,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
     constexpr ck_tile::index_t M_Warp_Tile = 32;
     constexpr ck_tile::index_t N_Warp_Tile = 32;
     constexpr ck_tile::index_t K_Warp_Tile = 8;
-#else
+
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
     // Compute friendly for Intrawave scheduler
     constexpr ck_tile::index_t M_Tile = 256;
     constexpr ck_tile::index_t N_Tile = 256;
@@ -63,8 +71,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;

     using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
     using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+    using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<
+#endif
         ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;

     const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
@@ -77,13 +88,21 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
         constexpr bool has_hot_loop_v = has_hot_loop_.value;
         constexpr auto tail_number_v  = tail_number_.value;

+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
         using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+        using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<
+#endif
             ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                   BDataType,
                                                   AccDataType,
                                                   GemmShape,
                                                   Traits,
+#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
                                                   ck_tile::GemmPipelineScheduler::Interwave,
+#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
+                                                  ck_tile::GemmPipelineScheduler::Intrawave,
+#endif
                                                   has_hot_loop_v,
                                                   tail_number_v>>;
         using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
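With this change, gemm_basic.cpp picks its pipeline at compile time from CK_TILE_PIPELINE_DEFAULT: the compute-friendly Intrawave CompV3 pipeline by default, or the memory-friendly Interwave pipeline when the macro equals CK_TILE_PIPELINE_MEMORY. Below is a hedged sketch of overriding the default, assuming the example is configured through plain CMake; the exact build driver may differ.

```
# Sketch, not part of the change: select the memory-friendly pipeline at configure time.
# CK_TILE_PIPELINE_MEMORY expands to 2 in gemm_basic.cpp, and the default is only applied
# when CK_TILE_PIPELINE_DEFAULT has not been defined already.
cmake -DCMAKE_CXX_FLAGS="-DCK_TILE_PIPELINE_DEFAULT=2" ..
make tile_example_gemm_basic -j
```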
......
add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
# Batched GEMM
This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable `build/bin/tile_example_batched_gemm`
## example
```
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
-stride_a Tensor A stride (default:128)
-stride_b Tensor B stride (default:128)
-stride_c Tensor C stride (default:128)
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
 -warmup number of iterations before benchmark the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
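A typical invocation might look like the sketch below; the flag syntax follows the listing above, and the values shown are the defaults.

```
# run from the build folder after `make tile_example_batched_gemm`
./bin/tile_example_batched_gemm -m=256 -n=128 -k=128 -batch_count=16 -prec=fp16 -v=2
```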
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2;
constexpr int kBlockPerCu = 1;
// This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 128;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr bool CShuffleEpilogue =
std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTilePartitioner<CodegenGemmShape>;
using GemmEpilogue = std::conditional_t<
CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType,
kPadM,
kPadN,
kTilePermute,
kOutputRank,
1,
0,
TilePartitioner::kM,
TilePartitioner::kN>>,
ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
#include "run_batched_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
template <typename DataType>
struct BatchedGemmTypeConfig;
template <>
struct BatchedGemmTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
using Types = BatchedGemmTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "128", "n dimension")
.insert("k", "128", "k dimension")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "R", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("batch_stride_a", "32768", "Batch A stride")
.insert("batch_stride_b", "16384", "Batch B stride")
.insert("batch_stride_c", "32768", "Batch C stride")
.insert("batch_count", "16", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s);
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t batch_stride_A,
ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count,
int n_warmup,
int n_repeat)
{
batched_gemm_kargs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.batch_stride_A = batch_stride_A;
args.batch_stride_B = batch_stride_B;
args.batch_stride_C = batch_stride_C;
args.batch_count = batch_count;
float ave_time = batched_gemm<ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Batched Gemm"};
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * N * K +
sizeof(CDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
int run_batched_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
using namespace ck_tile::literals;
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
std::size_t row,
std::size_t col,
std::size_t stride,
std::size_t batch_stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
{batch_stride, 1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
            // if stride is zero, fall back to a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
stride_C = f_get_default_stride(M, N, stride_C, c_layout);
ck_tile::HostTensor<ADataType> a_m_k(
f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout));
ck_tile::HostTensor<BDataType> b_k_n(
f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
const auto b_n_k = b_k_n.transpose({0, 2, 1});
ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_n_k, c_m_n_host_ref);
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_gpu_buf_ref,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}
int run_batched_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "R")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}
    // TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
    // work
    // else if(a_layout == "C" && b_layout == "C")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Col{}, Row{});
// }
// else if(a_layout == "C" && b_layout == "R")
// {
// return run_batched_gemm_example_with_layouts(argc, argv, Col{}, Row{}, Row{});
// }
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
@@ -15,4 +15,4 @@ add_subdirectory(12_smoothquant)
 add_subdirectory(13_moe_sorting)
 add_subdirectory(14_moe_smoothquant)
 add_subdirectory(15_fused_moe)
+add_subdirectory(16_batched_gemm)
@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
     return;
 }
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
DeviceMem& b_device,
DeviceMem& c_device,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t batch_stride_A,
index_t batch_stride_B,
index_t batch_stride_C,
index_t batch_count)
{
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(d_A,
a_device.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(d_B,
b_device.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{
ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
}
errC = hipMemcpy(c_device.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return;
}
} // namespace ck_tile
@@ -25,6 +25,9 @@
 #include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
+#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
+#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
......
@@ -41,13 +41,16 @@ struct BlockUniversalGemmAsBsCr
    static constexpr index_t MWarp = config.template at<1>();
    static constexpr index_t NWarp = config.template at<2>();

    using I0 = number<0>;
    using I1 = number<1>;

    static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
                  "Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
    static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
                  "Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
    static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
                  "Error! WarpGemm's M is not consisten with BlockGemmShape!");
    static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
                  "Error! WarpGemm's N is not consisten with BlockGemmShape!");

    static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
@@ -99,6 +102,9 @@ struct BlockUniversalGemmAsBsCr

    static constexpr auto Scheduler = Traits::Scheduler;

    using I0 = number<0>;
    using I1 = number<1>;

    private:
    template <GemmPipelineScheduler Scheduler, typename GemmTraits>
    struct BlockGemmImpl
@@ -114,35 +120,31 @@
                                       const ASmemBlockWindow& a_block_window,
                                       const BSmemBlockWindow& b_block_window)
        {
            static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                          "The CDataType as defined in traits should be the same as correspoinding "
                          "C block tensor data type!");
            static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                              std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                          "The ADataType and BDataType as defined in "
                          "traits should be the same as correspoinding block window data type!");
            static_assert(
                GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                "MPerBlock, NPerBlock, KPerBlock defined in "
                " BlockGemmShape are different from A/B block smem windows apropriate dims!");

            const index_t iMWarp = get_warp_id() / NWarp;
            const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

            // TODO: refactor warp_window tile type to class member as it should be
            // compile-time known information.
            auto a_warp_window_tmp = make_tile_window(
                a_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
                a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
                make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

            using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
@@ -156,16 +158,15 @@ struct BlockUniversalGemmAsBsCr
            statically_indexed_array<
                statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
                MIterPerWarp>
                a_warp_windows;

            // construct B-warp-window
            auto b_warp_window_tmp = make_tile_window(
                b_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
                b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
                make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

            using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
@@ -179,10 +180,10 @@ struct BlockUniversalGemmAsBsCr
            statically_indexed_array<
                statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
                NIterPerWarp>
                b_warp_windows;

            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
                    a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
@@ -193,7 +194,7 @@ struct BlockUniversalGemmAsBsCr
                });
            });

            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
                    b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
@@ -203,8 +204,8 @@ struct BlockUniversalGemmAsBsCr
                });
            });

            using CWarpDstr   = typename WarpGemm::CWarpDstr;
            using CWarpTensor = typename WarpGemm::CWarpTensor;

            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
@@ -212,10 +213,10 @@ struct BlockUniversalGemmAsBsCr
            // hot loop:
            static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));

                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                        const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));

                        // read C warp tensor from C block tensor-
@@ -226,7 +227,7 @@ struct BlockUniversalGemmAsBsCr
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        // warp GEMM
                        WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);

                        // write C warp tensor into C block tensor
                        c_block_tensor.set_y_sliced_thread_data(
@@ -243,13 +244,13 @@ struct BlockUniversalGemmAsBsCr
    struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
    {
        statically_indexed_array<
            statically_indexed_array<typename GemmTraits::AWarpTile, KIterPerWarp>,
            MIterPerWarp>
            a_warp_tiles_;

        statically_indexed_array<
            statically_indexed_array<typename GemmTraits::BWarpTile, KIterPerWarp>,
            NIterPerWarp>
            b_warp_tiles_;

        template <typename ASmemBlockWindow, typename BSmemBlockWindow>
@@ -257,30 +258,27 @@ struct BlockUniversalGemmAsBsCr
                                     const BSmemBlockWindow& b_block_window)
        {
            static_assert(
                GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                "MPerBlock, NPerBlock, KPerBlock defined in "
                " BlockGemmShape are different from A/B block smem windows apropriate dims!");
            static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                              std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                          "The ADataType and BDataType as defined in "
                          "traits should be the same as correspoinding block window data type!");

            const index_t iMWarp = get_warp_id() / NWarp;
            const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

            // TODO: refactor warp_window tile type to class member as it should be
            // compile-time known information.
            auto a_warp_window_tmp = make_tile_window(
                a_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
                a_block_window.get_window_origin() + multi_index<2>{iMWarp * WarpGemm::kM, 0},
                make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

            using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
@@ -292,18 +290,16 @@ struct BlockUniversalGemmAsBsCr
                          AWarpWindow{}.get_window_lengths(),
                          "AWarpWindow lengths must be equal to AWarpTile lengths!");

            statically_indexed_array<statically_indexed_array<AWarpWindow, KIterPerWarp>,
                                     MIterPerWarp>
                a_warp_windows;

            // construct B-warp-window
            auto b_warp_window_tmp = make_tile_window(
                b_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
                b_block_window.get_window_origin() + multi_index<2>{iNWarp * WarpGemm::kN, 0},
                make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

            using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
@@ -315,13 +311,12 @@ struct BlockUniversalGemmAsBsCr
                          BWarpWindow{}.get_window_lengths(),
                          "BWarpWindow lengths must be equal to BWarpTile lengths!");

            statically_indexed_array<statically_indexed_array<BWarpWindow, KIterPerWarp>,
                                     NIterPerWarp>
                b_warp_windows;

            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                    a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
                    // TODO: I don't have to move 0,0 window!
@@ -331,8 +326,8 @@ struct BlockUniversalGemmAsBsCr
                });
            });

            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                    b_warp_windows(nIter)(kIter) = b_warp_window_tmp;

                    move_tile_window(b_warp_windows(nIter)(kIter),
@@ -341,12 +336,12 @@ struct BlockUniversalGemmAsBsCr
                });
            });

            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block window
                    load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
                });
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B Block window
                    load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
                });
@@ -359,22 +354,21 @@ struct BlockUniversalGemmAsBsCr
                                       [[maybe_unused]] const ASmemBlockWindow& a_block_window,
                                       [[maybe_unused]] const BSmemBlockWindow& b_block_window)
        {
            static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
                          "The CDataType as defined in traits should be the same as correspoinding "
                          "C block tensor data type!");

            using CWarpDstr   = typename WarpGemm::CWarpDstr;
            using CWarpTensor = typename WarpGemm::CWarpTensor;

            constexpr auto c_warp_y_lengths =
                to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
            constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

            // hot loop:
            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                        // read C warp tensor from C block tensor-
                        CWarpTensor c_warp_tensor;
@@ -383,9 +377,9 @@ struct BlockUniversalGemmAsBsCr
                            merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                        // warp GEMM
                        WarpGemm{}(c_warp_tensor,
                                   a_warp_tiles_[mIter][kIter],
                                   b_warp_tiles_[nIter][kIter]);

                        // write C warp tensor into C block tensor
                        c_block_tensor.set_y_sliced_thread_data(
@@ -412,12 +406,12 @@ struct BlockUniversalGemmAsBsCr
        statically_indexed_array<
            statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
            MIterPerWarp>
            a_warp_tiles_;

        statically_indexed_array<
            statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
            NIterPerWarp>
            b_warp_tiles_;

        template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
@@ -425,30 +419,28 @@ struct BlockUniversalGemmAsBsCr
                                const BSmemBlockWindow& b_block_window)
        {
            static_assert(
                GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[I0{}] &&
                    GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[I1{}],
                "MPerBlock, NPerBlock, KPerBlock defined in "
                " BlockGemmShape are different from A/B block smem windows apropriate dims!");
            static_assert(std::is_same_v<ADataType, typename ASmemBlockWindow::DataType> &&
                              std::is_same_v<BDataType, typename BSmemBlockWindow::DataType>,
                          "The ADataType and BDataType as defined in "
                          "traits should be the same as correspoinding block window data type!");

            const index_t iMWarp = get_warp_id() / NWarp;
            const index_t iNWarp = get_warp_id() - (iMWarp * NWarp);

            // TODO: refactor warp_window tile type to class member as it should be
            // compile-time known information.
            auto a_warp_window_tmp = make_tile_window(
                a_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
                a_block_window.get_window_origin() +
                    multi_index<2>{iMWarp * WarpGemm::kM, KIdx * KPerInnerLoop},
                make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));

            using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
@@ -461,16 +453,16 @@ struct BlockUniversalGemmAsBsCr
                          "AWarpWindow lengths must be equal to AWarpTile lengths!");

            statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
                                     MIterPerWarp>
                a_warp_windows;

            // construct B-warp-window
            auto b_warp_window_tmp = make_tile_window(
                b_block_window.get_bottom_tensor_view(),
                make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
                b_block_window.get_window_origin() +
                    multi_index<2>{iNWarp * WarpGemm::kN, KIdx * KPerInnerLoop},
                make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));

            using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
@@ -483,10 +475,10 @@ struct BlockUniversalGemmAsBsCr
                          "BWarpWindow lengths must be equal to BWarpTile lengths!");

            statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
                                     NIterPerWarp>
                b_warp_windows;

            static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
                    a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
@@ -496,7 +488,7 @@ struct BlockUniversalGemmAsBsCr
                });
            });

            static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
                    b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
@@ -508,11 +500,11 @@ struct BlockUniversalGemmAsBsCr
            // TODO check if a_warp_tiles has same desc as a_warp_window
            static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
                static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
                    // read A warp tensor from A block window
                    load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
                });
                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
                    // read B warp tensor from B Block window
                    load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
                });
...@@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr ...@@ -525,13 +517,12 @@ struct BlockUniversalGemmAsBsCr
const ASmemBlockWindow& a_block_window, const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
static_assert( static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>, "The CDataType as defined in traits should be the same as corresponding "
"The CDataType as defined in traits should be the same as corresponding " "C block tensor data type!");
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths = constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...@@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr ...@@ -555,8 +546,8 @@ struct BlockUniversalGemmAsBsCr
} }
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) { static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor- // read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor; CWarpTensor c_warp_tensor;
...@@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr ...@@ -573,17 +564,17 @@ struct BlockUniversalGemmAsBsCr
// penalty // penalty
if constexpr(kIter.value == KRepeat - 1 && if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 && kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == GemmTraits::MIterPerWarp - 1 && mIter.value == MIterPerWarp - 1 &&
nIter.value == GemmTraits::NIterPerWarp - 1) nIter.value == NIterPerWarp - 1)
{ {
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
block_sync_lds(); block_sync_lds();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
} }
// warp GEMM // warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kInnerIter], a_warp_tiles_[mIter][kInnerIter],
b_warp_tiles_[nIter][kInnerIter]); b_warp_tiles_[nIter][kInnerIter]);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
...@@ -632,7 +623,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -632,7 +623,7 @@ struct BlockUniversalGemmAsBsCr
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window); block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
} }
// C += A * B // C += A * B
...@@ -641,7 +632,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -641,7 +632,7 @@ struct BlockUniversalGemmAsBsCr
const ASmemBlockWindow& a_block_window, const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window); block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
} }
// C = A * B // C = A * B
...@@ -650,7 +641,7 @@ struct BlockUniversalGemmAsBsCr ...@@ -650,7 +641,7 @@ struct BlockUniversalGemmAsBsCr
const BSmemBlockWindow& b_block_window) const BSmemBlockWindow& b_block_window)
{ {
auto c_block_tensor = MakeCBlockTile(); auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window); block_gemm_impl_(c_block_tensor, a_block_window, b_block_window);
return c_block_tensor; return c_block_tensor;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
struct BatchedGemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
struct BatchedGemmKargs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
index_t batch_stride_A;
index_t batch_stride_B;
index_t batch_stride_C;
index_t batch_count;
};
using Kargs = BatchedGemmKargs;
using Hargs = BatchedGemmHostArgs;
__host__ static constexpr auto GridSize(const Hargs& h)
{
return TilePartitioner::GridSize(h.M, h.N, h.batch_count);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr BatchedGemmKargs MakeKargs(const Hargs& h)
{
Kargs k;
k.a_ptr = h.a_ptr;
k.b_ptr = h.b_ptr;
k.c_ptr = h.c_ptr;
k.M = h.M;
k.N = h.N;
k.K = h.K;
k.stride_A = h.stride_A;
k.stride_B = h.stride_B;
k.stride_C = h.stride_C;
k.batch_stride_A = h.batch_stride_A;
k.batch_stride_B = h.batch_stride_B;
k.batch_stride_C = h.batch_stride_C;
k.batch_count = h.batch_count;
return k;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const auto [i_m, i_n] = TilePartitioner{}();
const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.z);
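// Note (rationale not stated in this file): __builtin_amdgcn_readfirstlane broadcasts the
// first active lane's value, so the batch index and the batch offsets computed below stay
// uniform and can live in scalar registers instead of being recomputed per lane.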
// options
const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start + batch_offset_A,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<1>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_start + batch_offset_B,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
auto a_pad_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
auto b_pad_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
// clang-format on
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
// Run GEMM cooperatively by the whole workgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
c_start + batch_offset_C,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
auto c_pad_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(c_block_window, c_block_tile);
}
};
} // namespace ck_tile
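The batched kernel above exposes a small host-side surface: fill BatchedGemmHostArgs, convert it with MakeKargs, and size the launch with GridSize/BlockSize. The sketch below is a minimal, hedged illustration of that flow; the batched_gemm_entry wrapper, the row-major stride/batch-stride choices, and the plain triple-chevron launch are assumptions made for this example and are not taken from the file above.

// Minimal host-side sketch (assumptions noted above): drives BatchedGemmKernel through its
// own MakeKargs/GridSize/BlockSize helpers. Not part of the repository code.
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
// plus the header above that defines ck_tile::BatchedGemmKernel / BatchedGemmHostArgs

template <typename Kernel>
__global__ void batched_gemm_entry(typename Kernel::Kargs kargs)
{
    Kernel{}(kargs);
}

template <typename Kernel>
void run_batched_gemm(const void* a, const void* b, void* c,
                      ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                      ck_tile::index_t batch_count, hipStream_t stream)
{
    // Row-major A/B/C assumed for this sketch: the leading dimension is the contiguous
    // extent and each batch is a dense M*K / K*N / M*N slab.
    ck_tile::BatchedGemmHostArgs h{a, b, c, M, N, K,
                                   /*stride_A=*/K, /*stride_B=*/N, /*stride_C=*/N,
                                   /*batch_stride_A=*/M * K, /*batch_stride_B=*/K * N,
                                   /*batch_stride_C=*/M * N, batch_count};
    const auto kargs = Kernel::MakeKargs(h);
    const auto grid  = Kernel::GridSize(h);   // one z-slice per batch, assumed dim3-compatible
    const auto block = Kernel::BlockSize();
    batched_gemm_entry<Kernel><<<grid, block, 0, stream>>>(kargs);
}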
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmPipelineAgBgCrImplBase
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
template <typename DstBlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
}
template <typename ADramBlockWindowTmp, typename ALdsTensorView>
CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const ALdsTensorView& a_lds_block_view) const
{
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block_view,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
auto a_lds_gemm_window = make_tile_window(
a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(a_copy_dram_window),
std::move(a_copy_lds_window),
std::move(a_lds_gemm_window));
}
template <typename BDramBlockWindowTmp, typename BLdsTensorView>
CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BLdsTensorView& b_lds_block_view) const
{
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
auto b_lds_gemm_window = make_tile_window(
b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
return make_tuple(std::move(b_copy_dram_window),
std::move(b_copy_lds_window),
std::move(b_lds_gemm_window));
}
};
} // namespace ck_tile
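GetABLdsTensorViews above carves a single LDS allocation in two: the A tile's byte footprint is rounded up to a 16-byte boundary and the B tile view starts immediately after it. The standalone sketch below works that arithmetic through with illustrative 128x32 fp16 tiles; the helper name align_up_16 and the concrete numbers are assumptions, not values from any Problem definition.

// Standalone sketch of the LDS partitioning arithmetic (illustrative numbers only).
#include <cstddef>
#include <cstdio>

constexpr std::size_t align_up_16(std::size_t bytes) { return (bytes + 15) / 16 * 16; }

int main()
{
    constexpr std::size_t a_elems = 128 * 32;                  // MPerBlock * KPerBlock
    constexpr std::size_t b_elems = 128 * 32;                  // NPerBlock * KPerBlock
    constexpr std::size_t a_bytes = align_up_16(a_elems * 2);  // fp16 A tile, 16B-aligned
    constexpr std::size_t b_bytes = b_elems * 2;                // fp16 B tile
    // The B view starts at p_smem + a_bytes; the pipeline needs at least a_bytes + b_bytes.
    std::printf("A offset 0, B offset %zu, total %zu bytes\n", a_bytes, a_bytes + b_bytes);
    return 0;
}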
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV3
{
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
};
// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA;
static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = KPerXDL;
constexpr index_t B_LDS_Read_Width = KPerXDL;
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
constexpr auto num_ds_read_inst_a = A_LDS_Read_Width * sizeof(ADataType) == 16
? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b = B_LDS_Read_Width * sizeof(BDataType) == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) /
// sizeof(BDataType)
// ? sizeof(ComputeDataType) /
// sizeof(ADataType) : sizeof(ComputeDataType)
// / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
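// Worked example of the counts above (illustrative numbers only, not taken from any shipped
// Problem definition): BlockSize = 256, WaveSize = 64, MPerBlock = NPerBlock = 128,
// KPerBlock = 32, 32x32x8 warp tile (MPerXDL = NPerXDL = 32, KPerXDL = 8),
// WaveNumM = WaveNumN = 2, fp16 A/B with VectorSizeA = VectorSizeB = 8:
//   A/B buffer loads : 128*32 / (256*8)           = 2 each
//   A/B LDS writes   : 128*32 / (256*8)           = 2 each
//   A/B LDS reads    : 2 * 128*32 / (256*8)       = 4 each (16-byte ds_read, issue cycle 8)
//   MFMAs            : 128*128*32 / 4 / (32*32*8) = 16
//   ds_read_mfma_rate: (32 - 4 + 2*8 - 1) / (2*8) = 2, i.e. 2 MFMA slots per operand's reads
// Stage 1 then interleaves, per buffer load, one ds_write + MFMA pair, one vmem load and two
// more MFMAs (num_mfma_per_issue = 12/4 = 3); stage 2 pairs the remaining 4 MFMAs with the
// LDS reads, which accounts for all 16 MFMAs of the hot-loop iteration.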
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
// B DRAM tile window for load
// B LDS tile window for store
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_block_tile;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
// Leak the last MFMA block into the epilogue region to cover the potential
// LDS-shuffle latency.
// __builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}
};
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem);
}
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem);
}
};
} // namespace ck_tile
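GemmPipelineAgBgCrCompV3 runs a one-stage prologue (global prefetch plus LDS prefill), then num_loop - 1 hot-loop iterations scheduled by HotLoopScheduler, and a single tail block_gemm when TailNumber::Full is reported. The sketch below illustrates that accounting on the host side; the K/KPerBlock values and the ceil-division used for num_loop are assumptions for illustration (the real count comes from the TilePartitioner's GetLoopNum).

// Standalone sketch of the loop accounting (illustrative values only).
#include <cstdio>

int main()
{
    constexpr int KPerBlock      = 32;     // illustrative block K
    constexpr int K              = 4096;   // illustrative problem K
    constexpr int PrefetchStages = 2;      // BaseGemmPipelineAgBgCrCompV3::PrefetchStages

    const int num_loop     = (K + KPerBlock - 1) / KPerBlock;  // 128 K-blocks
    const bool has_hotloop = num_loop > PrefetchStages;        // BlockHasHotloop()
    // GetBlockLoopTailNum() always reports TailNumber::Full for this pipeline, so the body
    // runs (num_loop - 1) hot-loop iterations and one tail block_gemm call.
    std::printf("num_loop=%d hot_loop=%d hot_iters=%d tail_gemms=1\n",
                num_loop, has_hotloop, num_loop - 1);
    return 0;
}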
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -90,7 +91,8 @@ struct BaseGemmPipelineAgBgCrMem
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrMem<Problem>; using Base = BaseGemmPipelineAgBgCrMem<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
...@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -103,8 +105,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>; using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using I0 = number<0>; using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
...@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -124,46 +127,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
using Base::PrefetchStages; using Base::PrefetchStages;
CK_TILE_HOST_DEVICE constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl struct PipelineImpl : public PipelineImplBase
{ {
}; };
template <> template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{ {
template <typename DstBlockTile, typename SrcTileWindow> using Base = PipelineImplBase;
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop, template <bool HasHotLoop,
TailNumber TailNum, TailNumber TailNum,
...@@ -185,66 +162,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -185,66 +162,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); " or KPerBlock!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
// A tile in LDS // A/B tiles in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem); // With c++20 could simplify to below line.
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); // Currently get error: captured structured bindings are a C++20 extension
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
// TODO: LDS alignment should come from Policy! auto& a_lds_block = ab_lds_blocks.at(I0{});
constexpr index_t a_lds_block_space_size_aligned = auto& b_lds_block = ab_lds_blocks.at(I1{});
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = // A LDS tile for block GEMM
make_tile_window(a_lds_block, auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), auto& a_copy_dram_window = a_windows.at(I0{});
{0, 0}, auto& a_copy_lds_window = a_windows.at(I1{});
a_copy_dram_window.get_tile_distribution()); auto& a_lds_gemm_window = a_windows.at(I2{});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B DRAM tile window for load
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM // B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window( auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -266,20 +215,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch // prefetch
// global read 0 // global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [1, PrefetchStages] // Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
}); });
// main body // main body
...@@ -295,19 +244,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -295,19 +244,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds(); block_sync_lds();
LocalPrefill( Base::LocalPrefill(
a_copy_lds_window, a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func); a_element_func);
LocalPrefill( Base::LocalPrefill(
b_copy_lds_window, b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func); b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window); a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window); b_copy_dram_window);
}); });
i += PrefetchStages; i += PrefetchStages;
...@@ -323,12 +272,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -323,12 +272,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds(); block_sync_lds();
LocalPrefill(a_copy_lds_window, Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}), a_block_tiles.get(number<prefetch_idx>{}),
a_element_func); a_element_func);
LocalPrefill(b_copy_lds_window, Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}), b_block_tiles.get(number<prefetch_idx>{}),
b_element_func); b_element_func);
}); });
block_sync_lds(); block_sync_lds();
...@@ -376,24 +325,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -376,24 +325,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
}; };
template <> template <>
struct PipelineImpl<GemmPipelineScheduler::Interwave> struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
{ {
template <typename DstBlockTile, typename SrcTileWindow> using Base = PipelineImplBase;
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile(dst_block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop, template <bool HasHotLoop,
TailNumber TailNum, TailNumber TailNum,
...@@ -415,66 +349,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -415,66 +349,38 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate " "A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"); "([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock" "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"); " or KPerBlock!");
// ------------------------------------------------------------------------------------ // ------------------------------------------------------------------------------------
// Definitions of all needed tiles // Definitions of all needed tiles
// A tile in LDS // A/B tiles in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem); // With c++20 could simplify to below line.
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); // Currently get error: captured structured bindings are a C++20 extension
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); // auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
auto ab_lds_blocks = Base::GetABLdsTensorViews(p_smem);
// TODO: LDS alignment should come from Policy! auto& a_lds_block = ab_lds_blocks.at(I0{});
constexpr index_t a_lds_block_space_size_aligned = auto& b_lds_block = ab_lds_blocks.at(I1{});
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load // A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = // A LDS tile for block GEMM
make_tile_window(a_lds_block, auto a_windows = Base::GetAWindows(a_dram_block_window_tmp, a_lds_block);
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), auto& a_copy_dram_window = a_windows.at(I0{});
{0, 0}, auto& a_copy_lds_window = a_windows.at(I1{});
a_copy_dram_window.get_tile_distribution()); auto& a_lds_gemm_window = a_windows.at(I2{});
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B DRAM tile window for load
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM // B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window( auto b_windows = Base::GetBWindows(b_dram_block_window_tmp, b_lds_block);
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); auto& b_copy_dram_window = b_windows.at(I0{});
auto& b_copy_lds_window = b_windows.at(I1{});
auto& b_lds_gemm_window = b_windows.at(I2{});
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -496,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -496,20 +402,20 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// prefetch // prefetch
// global read 0 // global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C // initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [1, PrefetchStages] // Global prefetch [1, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window); Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window); Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
}); });
// main body // main body
...@@ -523,19 +429,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -523,19 +429,19 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave // no second block_sync_lds because it's interwave
LocalPrefill( Base::LocalPrefill(
a_copy_lds_window, a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func); a_element_func);
LocalPrefill( Base::LocalPrefill(
b_copy_lds_window, b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func); b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window); a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window); b_copy_dram_window);
}); });
i += PrefetchStages; i += PrefetchStages;
...@@ -548,12 +454,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -548,12 +454,12 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave // no second block_sync_lds because it's interwave
LocalPrefill(a_copy_lds_window, Base::LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<prefetch_idx>{}), a_block_tiles.get(number<prefetch_idx>{}),
a_element_func); a_element_func);
LocalPrefill(b_copy_lds_window, Base::LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<prefetch_idx>{}), b_block_tiles.get(number<prefetch_idx>{}),
b_element_func); b_element_func);
}); });
block_sync_lds(); block_sync_lds();
......
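The Mem-pipeline diff above replaces structured bindings for the LDS views and windows with a tuple plus .at(I0{})/.at(I1{}) references, citing the clang error "captured structured bindings are a C++20 extension". The standalone sketch below (not ck_tile code; make_views and the int views are placeholders) shows the limitation and the workaround in isolation.

// Minimal standalone sketch of the pre-C++20 limitation and the tuple workaround.
#include <tuple>

std::tuple<int, int> make_views() { return std::make_tuple(1, 2); }

int use_views()
{
    // C++20 only; clang in C++17 mode rejects the capture with
    // "captured structured bindings are a C++20 extension":
    //   auto&& [a_view, b_view] = make_views();
    //   auto reader = [&] { return a_view + b_view; };

    // Pre-C++20 workaround mirroring the pipeline: keep the tuple and take references.
    auto  views  = make_views();
    auto& a_view = std::get<0>(views);
    auto& b_view = std::get<1>(views);
    auto  reader = [&] { return a_view + b_view; };
    return reader();
}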