Commit 8bd49370 authored by Adam Osewski's avatar Adam Osewski
Browse files

Refactoring & Move Layout info to pipeline problem.

parent d3689b06
......@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_basic.hpp"
#include <hip/hip_runtime.h>
#include <cstring>
......@@ -11,6 +10,11 @@
#include <string>
#include <tuple>
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
......@@ -22,7 +26,6 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("e", "1e-5", "Absolute error tolerance")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
......@@ -51,13 +54,11 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel =
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args.p_a,
args.p_b,
args.p_c,
args.epsilon,
args.M,
args.N,
args.K,
......@@ -96,7 +97,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
return -1; // Or handle the error appropriately
}
float epsilon = arg_parser.get_float("e");
ck_tile::index_t batch_size = arg_parser.get_int("b");
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
......@@ -110,66 +110,34 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
args.p_a = a_buf.GetDeviceBuffer();
args.p_b = b_buf.GetDeviceBuffer();
args.p_c = c_buf.GetDeviceBuffer();
args.epsilon = epsilon;
args.kbatch = batch_size;
args.M = M;
args.N = N;
args.K = K;
// Only set stride_M and stride_N if they are non-zero and not equal to K.
if(stride_a != 0)
{
args.stride_A = stride_a;
}
else
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
args.stride_A = [&]() {
if constexpr(std::is_same_v<LayoutA, ck_tile::tensor_layout::gemm::ColumnMajor>)
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return M;
return col;
}
else
{
return K;
return row;
}
}();
}
if(stride_b != 0)
{
args.stride_B = stride_b;
}
else
{
args.stride_B = [&]() {
if constexpr(std::is_same_v<LayoutB, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return N;
}
else
{
return K;
}
}();
}
return stride;
};
if(stride_c != 0)
{
args.stride_C = stride_c;
}
else
{
args.stride_C = [&]() {
if constexpr(std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
return M;
}
else
{
return N;
}
}();
}
args.stride_A = f_get_default_stride(M, K, stride_a, LayoutA{});
args.stride_B = f_get_default_stride(K, N, stride_b, LayoutB{});
args.stride_C = f_get_default_stride(M, N, stride_c, LayoutC{});
float ave_time = gemm_calc<LayoutA, LayoutB, LayoutC, PipelineProblem, GemmPipeline, GemmShape>(
args, ck_tile::stream_config{nullptr, true});
......@@ -197,30 +165,57 @@ int main(int argc, char* argv[])
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
// The Matrix Multiplication goes with Matrix A (M, K), Matrix B (N, K) = Matrix C (M, N).
using matrix_a_layout = ck_tile::tensor_layout::gemm::RowMajor;
using matrix_b_layout = ck_tile::tensor_layout::gemm::ColumnMajor;
using matrix_c_layout = ck_tile::tensor_layout::gemm::RowMajor;
// host verify
std::vector<int> a_dimensions =
(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, K}
: std::vector<int>{K, M};
std::vector<int> b_dimensions =
(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
? std::vector<int>{N, K}
: std::vector<int>{K, N};
std::vector<int> c_dimensions =
(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::RowMajor>)
? std::vector<int>{M, N}
: std::vector<int>{N, M};
ck_tile::HostTensor<ADataType> a_host(a_dimensions);
ck_tile::HostTensor<BDataType> b_host(b_dimensions);
ck_tile::HostTensor<CDataType> c_host_ref(c_dimensions);
ck_tile::HostTensor<CDataType> c_host_dev(c_dimensions);
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
using namespace ck_tile::literals;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if(stride == 0)
{
// give a chance if stride is zero, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
ck_tile::HostTensor<ADataType> a_host(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_host(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
ck_tile::HostTensor<CDataType> c_host_ref(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::HostTensor<CDataType> c_host_dev(f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
......@@ -259,6 +254,9 @@ int main(int argc, char* argv[])
BDataType,
AccDataType,
CodegenGemmShape,
ALayout,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC>;
......@@ -266,9 +264,9 @@ int main(int argc, char* argv[])
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
invoke_gemm<ck_tile::half_t,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout,
ALayout,
BLayout,
CLayout,
CodegenPipelineProblem,
CodegenGemmPipeline,
CodegenGemmShape>(a_buf, b_buf, c_buf, arg_parser);
......@@ -280,17 +278,12 @@ int main(int argc, char* argv[])
if(arg_parser.get_int("v") == 1)
{
// ToDo: Will Add the Element Op (bias) verification in the future.
ck_tile::reference_gemm<ADataType,
BDataType,
AccDataType,
CDataType,
matrix_a_layout,
matrix_b_layout,
matrix_c_layout>(a_host, b_host, c_host_ref);
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_host, b_host, c_host_ref);
pass_cpu = ck_tile::check_err(c_host_dev, c_host_ref);
std::cout << "The CPU veification result is:" << (pass_cpu ? "correct" : "fail")
std::cout << "The CPU verification result is:" << (pass_cpu ? "correct" : "fail")
<< std::flush;
}
......@@ -298,57 +291,19 @@ int main(int argc, char* argv[])
if(arg_parser.get_int("v") == 2)
{
ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
ck_tile::index_t stride_c = arg_parser.get_int("stride_c");
if(stride_a == 0)
{
if constexpr(std::is_same_v<matrix_a_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_a = M;
}
else
{
stride_a = K;
}
}
if(stride_b == 0)
{
if constexpr(std::is_same_v<matrix_b_layout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_b = N;
}
else
{
stride_b = K;
}
}
if(stride_c == 0)
{
if constexpr(std::is_same_v<matrix_c_layout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_c = M;
}
else
{
stride_c = N;
}
}
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
ck_tile::HostTensor<CDataType> c_host_gpu_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
c_gpu_buf.SetZero();
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
a_buf, b_buf, c_gpu_buf, M, N, K, stride_A, stride_B, stride_C);
c_buf.FromDevice(c_host_gpu_ref.data());
c_gpu_buf.FromDevice(c_host_gpu_ref.data());
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
std::cout << "The GPU verification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush;
}
......
......@@ -4,12 +4,10 @@
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include <string>
template <typename DataType>
struct GemmBasicTypeConfig;
......@@ -58,7 +56,6 @@ struct gemm_basic_args
const void* p_a;
const void* p_b;
void* p_c;
float epsilon;
ck_tile::index_t kbatch;
ck_tile::index_t M;
ck_tile::index_t N;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace ck_tile {
namespace literals {
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline constexpr std::size_t operator""_uz(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
inline constexpr std::size_t operator""_zu(unsigned long long size)
{
return static_cast<std::size_t>(size);
}
} // namespace literals
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
namespace ck_tile {
......@@ -14,48 +15,36 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
const HostTensor<BDataType>& b_n_k,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const int N = b_n_k.mDesc.get_lengths()[0];
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[1]
: a_m_k.mDesc.get_lengths()[0];
const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_m_k.mDesc.get_lengths()[0]
: a_m_k.mDesc.get_lengths()[1];
auto f = [&](auto m) {
for(int n = 0; n < N; ++n)
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
for(int k = 0; k < K; ++k)
for(std::size_t k = 0; k < K; ++k)
{
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? a_element_op(a_m_k(m, k))
: a_element_op(a_m_k(k, m));
BDataType v_b = b_element_op(b_n_k(n, k));
ADataType v_a = a_element_op(a_m_k(m, k));
BDataType v_b = b_element_op(b_k_n(k, n));
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
ck_tile::type_convert<AccDataType>(v_b);
v_acc +=
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
}
};
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
......
......@@ -8,34 +8,29 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
template <typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
return TilePartitioner::GridSize(M, N, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
......@@ -45,30 +40,30 @@ struct GemmKernel
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
......@@ -79,12 +74,12 @@ struct GemmKernel
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
......@@ -93,14 +88,14 @@ struct GemmKernel
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
......@@ -110,7 +105,7 @@ struct GemmKernel
number<1>{});
}
else
{ // Default NK layout
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
......@@ -123,8 +118,8 @@ struct GemmKernel
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});
sequence < false,
GemmPipeline::kPadA ? true : false > {});
auto ABlockWindow = make_tile_window(
a_pad_view,
......@@ -134,8 +129,8 @@ struct GemmKernel
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});
sequence < false,
GemmPipeline::kPadB ? true : false > {});
auto BBlockWindow = make_tile_window(
b_pad_view,
......@@ -225,15 +220,15 @@ struct GemmKernel
}
}
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
......@@ -242,7 +237,7 @@ struct GemmKernel
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
number<1>{});
}
......@@ -251,13 +246,13 @@ struct GemmKernel
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
sequence < false,
GemmPipeline::kPadC ? true : false > {});
auto CBlockWindow = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
EpiloguePipeline{}(CBlockWindow, c_block_tile);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = BlockGemmPipelineAgBgCrDefaultPolicy>
struct BlockGemmPipelineAgBgCrMem
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
using CBlockTile = typename BlockGemm::CBlockTile;
using I0 = number<0>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t AlignmentA = Problem::AlignmentA;
static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2
? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
: 2;
static constexpr index_t LocalPrefillStages = 1;
static constexpr index_t GlobalBufferNum = PrefetchStages;
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
if(num_loop % PrefetchStages == 1)
{
return TailNumber::One;
}
else if(num_loop % PrefetchStages == 2)
{
return TailNumber::Two;
}
else if(num_loop % PrefetchStages == 3)
{
return TailNumber::Three;
}
else if(num_loop % PrefetchStages == 4)
{
return TailNumber::Four;
}
else if(num_loop % PrefetchStages == 5)
{
return TailNumber::Five;
}
else if(num_loop % PrefetchStages == 6)
{
return TailNumber::Six;
}
else if(num_loop % PrefetchStages == 7)
{
return TailNumber::Seven;
}
else
{
return TailNumber::Full;
}
}
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetStaticLdsSize()
{
return ck_tile::integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <BlockGemmPipelineScheduler Scheduler>
struct PipelineImpl
{
};
template <>
struct PipelineImpl<BlockGemmPipelineScheduler::Intrawave>
{
template <typename BlockTile, typename SrcTileWindow>
CK_TILE_DEVICE void GlobalPrefetch(BlockTile& block_tile,
SrcTileWindow& dram_tile_window) const
{
load_tile_raw(block_tile, dram_tile_window);
move_tile_window(dram_tile_window, {0, KPerBlock});
}
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
const SrcBlockTile& src_block_tile,
const ElementFunction& element_func) const
{
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
store_tile(lds_tile_window, block_tile_tmp);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
NPerBlock ==
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned =
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window =
make_tile_window(a_lds_block,
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load
auto b_copy_dram_window =
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store
auto b_copy_lds_window =
make_tile_window(b_lds_block,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM
constexpr auto block_gemm = BlockGemm();
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
// prefetch
// global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
// Global prefetch [2, PrefetchStages]
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
});
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
LocalPrefill(
a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
LocalPrefill(
b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
a_copy_dram_window);
GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
b_copy_dram_window);
});
i += PrefetchStages;
} while(i < (num_loop - PrefetchStages));
}
auto HotLoopTail = [&](auto tail_num) {
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
LocalPrefill(a_copy_lds_window,
a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
a_element_func);
LocalPrefill(b_copy_lds_window,
b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
b_element_func);
});
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
};
// TODO: TailNumber2Number
if constexpr(TailNum == TailNumber::One)
{
block_sync_lds();
// block_gemm.LocalPrefetch();
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else if constexpr(TailNum == TailNumber::Two)
{
HotLoopTail(number<2>{});
}
else if constexpr(TailNum == TailNumber::Three)
{
HotLoopTail(number<3>{});
}
else if constexpr(TailNum == TailNumber::Four)
{
HotLoopTail(number<4>{});
}
else if constexpr(TailNum == TailNumber::Five)
{
HotLoopTail(number<5>{});
}
else if constexpr(TailNum == TailNumber::Six)
{
HotLoopTail(number<6>{});
}
else if constexpr(TailNum == TailNumber::Seven)
{
HotLoopTail(number<7>{});
}
else if constexpr(TailNum == TailNumber::Full)
{
HotLoopTail(number<PrefetchStages>{});
}
return c_block_tile;
}
};
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
a_element_func,
b_dram_block_window_tmp,
b_element_func,
num_loop,
p_smem,
c_block_tile);
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
index_t num_loop,
void* p_smem,
CBlockTile& c_block_tile) const
{
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDataType& b) { return b; },
num_loop,
p_smem,
c_block_tile);
}
};
} // namespace ck_tile
......@@ -13,6 +13,9 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
bool kPadA_ = false,
bool kPadB_ = false,
bool kPadC_ = false>
......@@ -23,6 +26,10 @@ struct BlockGemmPipelineProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<ALayout_>;
using BLayout = remove_cvref_t<BLayout_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
......@@ -37,18 +44,29 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
bool kPadA_ = false,
bool kPadB_ = false,
bool kPadC_ = false,
BlockGemmPipelineScheduler Scheduler_ = BlockGemmPipelineScheduler::Intrawave>
struct BlockGemmUniversalPipelineProblem
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = false,
TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<ALayout_>;
using BLayout = remove_cvref_t<BLayout_>;
using CLayout = remove_cvref_t<CLayout_>;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = kPadA_;
......
......@@ -4,12 +4,12 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem_custom_policy.hpp"
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
using BlockGemmPipelineAgBgCrDefaultPolicy = BlockGemmPipelineAgBgCrMemCustomPolicy;
using GemmPipelineAgBgCrDefaultPolicy = GemmPipelineAgBgCrMemCustomPolicy;
} // namespace ck_tile
......@@ -9,14 +9,14 @@
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
struct BlockGemmPipelineAgBgCrMemCustomPolicy
struct GemmPipelineAgBgCrMemCustomPolicy
{
// 3d + padding
template <typename Problem>
......@@ -47,8 +47,6 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
......@@ -69,7 +67,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
......@@ -77,7 +75,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeB()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
......@@ -85,7 +83,7 @@ struct BlockGemmPipelineAgBgCrMemCustomPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
......
......@@ -3,9 +3,11 @@
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
enum struct BlockGemmPipelineScheduler
enum struct GemmPipelineScheduler
{
Intrawave,
Interwave,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment