Commit 93c11557 authored by Aleksander Dudek

CK Tile Batched Gemm

parent 3c171550
-add_executable(tile_example_batched_gemm_basic EXCLUDE_FROM_ALL batched_gemm_basic.cpp)
+add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
\ No newline at end of file
-# GEMM Matrix Multiplication
+# Batched GEMM

-This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
+This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.

 ## build
 ```
@@ -8,24 +8,27 @@
 mkdir build && cd build
 # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
 sh ../script/cmake-ck-dev.sh ../ <arch>
-make tile_example_gemm_basic -j
+make tile_example_batched_gemm -j
 ```
-This will result in an executable `build/bin/tile_example_gemm_basic`
+This will result in an executable `build/bin/tile_example_batched_gemm`

 ## example
 ```
 args:
-  -b               batch size (default:1)
-  -m               m dimension (default:1024)
-  -n               n dimension (default:2048)
-  -k               k dimension (default:64)
-  -stride_a        Tensor A stride (default:0)
-  -stride_b        Tensor B stride (default:0)
-  -stride_c        Tensor C stride (default:0)
+  -m               m dimension (default:256)
+  -n               n dimension (default:128)
+  -k               k dimension (default:128)
+  -stride_a        Tensor A stride (default:128)
+  -stride_b        Tensor B stride (default:128)
+  -stride_c        Tensor C stride (default:128)
+  -batch_stride_a  Batch A stride (default:32768)
+  -batch_stride_b  Batch B stride (default:16384)
+  -batch_stride_c  Batch C stride (default:32768)
+  -batch_count     Batch count (default:16)
   -v               0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
   -e               Absolute error tolerance (default:1e-5)
   -prec            data type. fp16/bf16/fp8/bf8 (default:fp16)
   -warmup          number of iterations before benchmark the kernel (default:10)
   -repeat          number of iterations to benchmark the kernel (default:100)
   -timer           gpu:gpu timer, cpu:cpu timer (default:gpu)
 ```
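For a quick sense of the defaults: the default batch strides are simply the per-matrix element counts for the default problem size, i.e. each batch's A, B, and C matrices are packed densely one after another. A minimal sketch (editor's illustration; the variable names below are hypothetical, not part of the example):

```cpp
// Illustration only, using the defaults listed above (M = 256, N = 128, K = 128).
int M = 256, N = 128, K = 128;
int batch_stride_a = M * K; // 256 * 128 = 32768 (default of -batch_stride_a)
int batch_stride_b = K * N; // 128 * 128 = 16384 (default of -batch_stride_b)
int batch_stride_c = M * N; // 256 * 128 = 32768 (default of -batch_stride_c)
```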
@@ -13,10 +13,10 @@
 #include "ck_tile/ops/epilogue.hpp"
 #include "ck_tile/ops/gemm.hpp"
 #include "ck_tile/host.hpp"
-#include "batched_gemm_basic.hpp"
+#include "batched_gemm.hpp"

 template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const batched_gemm_args& args, const ck_tile::stream_config& s)
 {
     // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
     constexpr bool kPadA = true;
...
@@ -10,10 +10,10 @@
 #include "ck_tile/host/kernel_launch.hpp"

 template <typename DataType>
-struct GemmBasicTypeConfig;
+struct BatchedGemmTypeConfig;

 template <>
-struct GemmBasicTypeConfig<ck_tile::half_t>
+struct BatchedGemmTypeConfig<ck_tile::half_t>
 {
     using ADataType = ck_tile::half_t;
     using BDataType = ck_tile::half_t;
@@ -43,7 +43,7 @@ struct DataTypeTraits<ck_tile::half_t>
     static constexpr const char* name = "fp16";
 };

-using Types = GemmBasicTypeConfig<ck_tile::half_t>;
+using Types = BatchedGemmTypeConfig<ck_tile::half_t>;

 // Specific type aliases for easy access
 using ADataType = Types::ADataType;
@@ -51,12 +51,11 @@ using BDataType = Types::BDataType;
 using AccDataType = Types::AccDataType;
 using CDataType = Types::CDataType;

-struct batched_gemm_basic_args
+struct batched_gemm_args
 {
     const void* p_a;
     const void* p_b;
     void* p_c;
-    ck_tile::index_t kbatch;
     ck_tile::index_t M;
     ck_tile::index_t N;
     ck_tile::index_t K;
@@ -72,8 +71,7 @@ struct batched_gemm_basic_args
 auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
-    arg_parser.insert("b", "1", "batch size")
-        .insert("m", "256", "m dimension")
+    arg_parser.insert("m", "256", "m dimension")
         .insert("n", "128", "n dimension")
         .insert("k", "128", "k dimension")
         .insert("stride_a", "128", "Tensor A stride")
@@ -94,4 +92,4 @@ auto create_args(int argc, char* argv[])
 }

 // host API
-float gemm_calc(batched_gemm_basic_args args, const ck_tile::stream_config& s);
+float gemm_calc(batched_gemm_args args, const ck_tile::stream_config& s);
@@ -4,28 +4,26 @@
 #pragma once

 template <typename ALayout, typename BLayout, typename CLayout>
-float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
+float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                           ck_tile::DeviceMem& b_k_n_dev_buf,
                           ck_tile::DeviceMem& c_m_n_dev_buf,
                           ck_tile::index_t M,
                           ck_tile::index_t N,
                           ck_tile::index_t K,
                           ck_tile::index_t stride_A,
                           ck_tile::index_t stride_B,
                           ck_tile::index_t stride_C,
-                          ck_tile::index_t kbatch,
                           ck_tile::index_t batch_stride_A,
                           ck_tile::index_t batch_stride_B,
                           ck_tile::index_t batch_stride_C,
                           ck_tile::index_t batch_count,
                           int n_warmup,
                           int n_repeat)
 {
-    batched_gemm_basic_args args;
+    batched_gemm_args args;
     args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
     args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
     args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-    args.kbatch = kbatch;
     args.M = M;
     args.N = N;
     args.K = K;
@@ -70,19 +68,11 @@ int run_batched_gemm_example(int argc, char* argv[])
     ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
     ck_tile::index_t stride_C = arg_parser.get_int("stride_c");

-    ck_tile::index_t batch_size = arg_parser.get_int("b");
     ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
     ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
     ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
     ck_tile::index_t batch_count = arg_parser.get_int("batch_count");

-    std::cout << "Received args: " << std::endl;
-    std::cout << "batch_stride_A: " << batch_stride_A << '\n'
-              << "batch_stride_B: " << batch_stride_B << '\n'
-              << "batch_stride_C: " << batch_stride_C << '\n'
-              << "batch_count: " << batch_count << std::endl;

     int n_warmup = arg_parser.get_int("warmup");
     int n_repeat = arg_parser.get_int("repeat");
@@ -92,19 +82,22 @@ int run_batched_gemm_example(int argc, char* argv[])
     using namespace ck_tile::literals;

-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
-            if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
-            {
-                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
-                                                     {row * col, stride, 1_uz});
-            }
-            else
-            {
-                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
-                                                     {row * col, 1_uz, stride});
-            }
-        };
+    auto f_host_tensor_descriptor = [](std::size_t batch_count_,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       auto layout) {
+        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
+                                                 {row * col, stride, 1_uz});
+        }
+        else
+        {
+            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
+                                                 {row * col, 1_uz, stride});
+        }
+    };

     auto f_get_default_stride = [](std::size_t row,
                                    std::size_t col,
@@ -130,10 +123,12 @@ int run_batched_gemm_example(int argc, char* argv[])
     stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
     stride_C = f_get_default_stride(M, N, stride_C, CLayout{});

-    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
-    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<ADataType> a_m_k(
+        f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_k_n(
+        f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
     ck_tile::HostTensor<CDataType> c_m_n_dev_result(
-        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));

     // TODO: add different init types
@@ -149,22 +144,21 @@ int run_batched_gemm_example(int argc, char* argv[])
     c_m_n_dev_buf.SetZero();
     c_m_n_dev_result.SetZero();

-    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
+    invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
                                                    b_k_n_dev_buf,
                                                    c_m_n_dev_buf,
                                                    M,
                                                    N,
                                                    K,
                                                    stride_A,
                                                    stride_B,
                                                    stride_C,
-                                                   batch_size,
                                                    batch_stride_A,
                                                    batch_stride_B,
                                                    batch_stride_C,
                                                    batch_count,
                                                    n_warmup,
                                                    n_repeat);

     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
     bool pass = true;
@@ -172,10 +166,10 @@ int run_batched_gemm_example(int argc, char* argv[])
     if(arg_parser.get_int("v") == 1)
     {
         ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));
         c_m_n_host_ref.SetZero();

-        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+        ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
             a_m_k, b_k_n, c_m_n_host_ref);

         pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
@@ -185,35 +179,35 @@ int run_batched_gemm_example(int argc, char* argv[])
     else if(arg_parser.get_int("v") == 2)
     {
         ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
-            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));
         ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
         c_m_n_gpu_ref.SetZero();
         c_m_n_gpu_buf_ref.SetZero();

-        ck_tile::reference_gemm_gpu<ADataType,
+        ck_tile::reference_batched_gemm_gpu<ADataType,
                                             BDataType,
                                             AccDataType,
                                             CDataType,
                                             ALayout,
                                             BLayout,
                                             CLayout>(a_m_k_dev_buf,
                                                      b_k_n_dev_buf,
                                                      c_m_n_gpu_buf_ref,
                                                      M,
                                                      N,
                                                      K,
                                                      stride_A,
                                                      stride_B,
                                                      stride_C,
                                                      batch_stride_A,
                                                      batch_stride_B,
                                                      batch_stride_C,
                                                      batch_count);

         c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
         pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
+        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
     }

     return pass;
...
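For orientation, the host tensors above are described as 3-D `{batch, row, col}` views with strides `{row * col, stride, 1}` (row-major) or `{row * col, 1, stride}` (column-major). A minimal sketch of the implied index math, assuming the README defaults; the function and variable names here are hypothetical and not part of the commit:

```cpp
#include <cstddef>

// Offset of element (b, m, k) of the batched A tensor under the descriptor
// built by f_host_tensor_descriptor, assuming row-major A with the defaults
// M = 256, K = 128, stride_a = 128 (so the per-batch stride is M * K = 32768).
std::size_t a_offset(std::size_t b, std::size_t m, std::size_t k)
{
    constexpr std::size_t M = 256, K = 128, stride_a = 128;
    return b * (M * K) + m * stride_a + k; // e.g. (1, 0, 0) -> 32768
}
```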
@@ -201,4 +201,117 @@ void reference_gemm_gpu(DeviceMem& a_device,
     return;
 }

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
                                DeviceMem& b_device,
                                DeviceMem& c_device,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    ADataType* d_A;
    BDataType* d_B;
    CDataType* d_C;

    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
    if(errA != hipSuccess)
    {
        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
                  << std::endl;
        return; // Early exit on error
    }
    if(errB != hipSuccess)
    {
        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
                  << std::endl;
        return; // Early exit on error
    }
    if(errC != hipSuccess)
    {
        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
                  << std::endl;
        return; // Early exit on error
    }

    errA = hipMemcpy(d_A,
                     a_device.GetDeviceBuffer(),
                     batch_count * M * K * sizeof(ADataType),
                     hipMemcpyHostToDevice);
    if(errA != hipSuccess)
    {
        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipMemcpy(d_B,
                     b_device.GetDeviceBuffer(),
                     batch_count * N * K * sizeof(BDataType),
                     hipMemcpyHostToDevice);
    if(errB != hipSuccess)
    {
        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
    }

    // Launch configuration: at least one thread per C element of a single batch.
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    // Run the naive reference kernel once per batch, offsetting each operand by
    // its batch stride.
    for(int i = 0; i < batch_count; ++i)
    {
        ADataType* d_ATemp = d_A + i * batch_stride_A;
        BDataType* d_BTemp = d_B + i * batch_stride_B;
        CDataType* d_CTemp = d_C + i * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    errC = hipMemcpy(c_device.GetDeviceBuffer(),
                     d_C,
                     batch_count * M * N * sizeof(CDataType),
                     hipMemcpyDeviceToHost);
    if(errC != hipSuccess)
    {
        std::cerr << "Error copying the C result back: " << hipGetErrorString(errC) << std::endl;
    }

    errA = hipFree(d_A);
    if(errA != hipSuccess)
    {
        std::cerr << "Error freeing the A memory: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipFree(d_B);
    if(errB != hipSuccess)
    {
        std::cerr << "Error freeing the B memory: " << hipGetErrorString(errB) << std::endl;
    }
    errC = hipFree(d_C);
    if(errC != hipSuccess)
    {
        std::cerr << "Error freeing the C memory: " << hipGetErrorString(errC) << std::endl;
    }
    return;
}
} // namespace ck_tile
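As a worked example of the launch configuration used by `reference_batched_gemm_gpu` above (editor's illustration, plugging in the README defaults M = 256, N = 128):

```cpp
// Illustration only: the grid is sized so there is at least one thread per C
// element of a single batch, and the naive kernel is launched batch_count times.
int totalElements      = 256 * 128; // M * N = 32768 output elements per batch
int numThreadsPerBlock = 256;
int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; // = 128
```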
@@ -96,13 +96,6 @@ struct BatchedGemmKernel
         const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) +
                                    __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
         // Convert pointers to tensor views
-        // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-        // {
-        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A): %d\n",
-        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A));
-        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B): %d\n",
-        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B));
-        // }
         auto a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
...