Commit 93c11557 authored by Aleksander Dudek

CK Tile Batched Gemm

parent 3c171550
-add_executable(tile_example_batched_gemm_basic EXCLUDE_FROM_ALL batched_gemm_basic.cpp)
\ No newline at end of file
+add_executable(tile_example_batched_gemm EXCLUDE_FROM_ALL batched_gemm.cpp)
\ No newline at end of file
-# GEMM Matrix Multiplication
+# Batched GEMM
-This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
+This folder contains an example of batched GEMM using the ck_tile tile-programming implementation.
## build
```
@@ -8,20 +8,23 @@ This folder contains example for GEMM using ck_tile tile-programming implementation
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
-make tile_example_gemm_basic -j
+make tile_example_batched_gemm -j
```
-This will result in an executable `build/bin/tile_example_gemm_basic`
+This will result in an executable `build/bin/tile_example_batched_gemm`
## example
```
args:
-    -b                batch size (default:1)
-    -m                m dimension (default:1024)
-    -n                n dimension (default:2048)
-    -k                k dimension (default:64)
-    -stride_a         Tensor A stride (default:0)
-    -stride_b         Tensor B stride (default:0)
-    -stride_c         Tensor C stride (default:0)
+    -m                m dimension (default:256)
+    -n                n dimension (default:128)
+    -k                k dimension (default:128)
+    -stride_a         Tensor A stride (default:128)
+    -stride_b         Tensor B stride (default:128)
+    -stride_c         Tensor C stride (default:128)
+    -batch_stride_a   Batch A stride (default:32768)
+    -batch_stride_b   Batch B stride (default:16384)
+    -batch_stride_c   Batch C stride (default:32768)
+    -batch_count      Batch count (default:16)
    -v                0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
    -e                Absolute error tolerance (default:1e-5)
    -prec             data type. fp16/bf16/fp8/bf8 (default:fp16)
...
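For orientation, the default batch strides above are simply the packed per-batch matrix sizes in elements for the default M/N/K. A minimal sketch of that arithmetic (illustrative only, assuming each batch's matrix is stored contiguously right after the previous one):

```
#include <cassert>
#include <cstdint>

int main()
{
    // Defaults from the argument list above.
    const int64_t M = 256, N = 128, K = 128;

    // With packed batches, the default batch strides are the per-batch matrix sizes in elements.
    assert(M * K == 32768); // batch_stride_a: one packed M x K A matrix per batch
    assert(K * N == 16384); // batch_stride_b: one packed K x N B matrix per batch
    assert(M * N == 32768); // batch_stride_c: one packed M x N C matrix per batch
    return 0;
}
```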
@@ -13,10 +13,10 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
-#include "batched_gemm_basic.hpp"
+#include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
-float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_config& s)
+float gemm_calc(const batched_gemm_args& args, const ck_tile::stream_config& s)
{
    // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
    constexpr bool kPadA = true;
...
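The comment above ("The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.") points at configuration that is currently hard-coded in `gemm_calc`. A hypothetical shape for such a generated configuration header, with names invented here purely for illustration and not part of CK or this commit, might be:

```
// Hypothetical generated header -- illustration only, not produced by this commit.
struct GemmCodegenConfig
{
    static constexpr bool kPadA       = true; // pad the A tile when M/K is not tile-aligned
    static constexpr bool kPadB       = true; // pad the B tile when N/K is not tile-aligned
    static constexpr bool kPadC       = true; // pad the C tile on store
    static constexpr int  kBlockPerCu = 1;    // occupancy hint: workgroups resident per CU
};
```

`gemm_calc` would then read these members instead of declaring local `constexpr` flags.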
@@ -10,10 +10,10 @@
#include "ck_tile/host/kernel_launch.hpp"
template <typename DataType>
-struct GemmBasicTypeConfig;
+struct BatchedGemmTypeConfig;
template <>
-struct GemmBasicTypeConfig<ck_tile::half_t>
+struct BatchedGemmTypeConfig<ck_tile::half_t>
{
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
@@ -43,7 +43,7 @@ struct DataTypeTraits<ck_tile::half_t>
    static constexpr const char* name = "fp16";
};
-using Types = GemmBasicTypeConfig<ck_tile::half_t>;
+using Types = BatchedGemmTypeConfig<ck_tile::half_t>;
// Specific type aliases for easy access
using ADataType = Types::ADataType;
@@ -51,12 +51,11 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
-struct batched_gemm_basic_args
+struct batched_gemm_args
{
    const void* p_a;
    const void* p_b;
    void* p_c;
-    ck_tile::index_t kbatch;
    ck_tile::index_t M;
    ck_tile::index_t N;
    ck_tile::index_t K;
@@ -72,8 +71,7 @@ struct batched_gemm_basic_args
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
-    arg_parser.insert("b", "1", "batch size")
-        .insert("m", "256", "m dimension")
+    arg_parser.insert("m", "256", "m dimension")
        .insert("n", "128", "n dimension")
        .insert("k", "128", "k dimension")
        .insert("stride_a", "128", "Tensor A stride")
@@ -94,4 +92,4 @@ auto create_args(int argc, char* argv[])
}
// host API
-float gemm_calc(batched_gemm_basic_args args, const ck_tile::stream_config& s);
+float gemm_calc(batched_gemm_args args, const ck_tile::stream_config& s);
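The struct above carries raw device pointers plus the problem sizes; the stride, batch-stride and batch-count fields are collapsed in this hunk but implied by the rest of the diff. As a rough mental model, and only as a sketch rather than code from this commit, a batched GEMM walks those buffers like this (row-major A, B and C assumed here for brevity):

```
#include <cstdint>

// Sketch: how stride and batch-stride fields are typically interpreted for
// row-major A (M x K), row-major B (K x N) and row-major C (M x N).
template <typename T>
void naive_batched_gemm_rowmajor(const T* a, const T* b, T* c,
                                 int64_t M, int64_t N, int64_t K,
                                 int64_t stride_a, int64_t stride_b, int64_t stride_c,
                                 int64_t batch_stride_a, int64_t batch_stride_b,
                                 int64_t batch_stride_c, int64_t batch_count)
{
    for(int64_t g = 0; g < batch_count; ++g) // one independent GEMM per batch
        for(int64_t m = 0; m < M; ++m)
            for(int64_t n = 0; n < N; ++n)
            {
                float acc = 0.f; // accumulate in float, mirroring AccDataType above
                for(int64_t k = 0; k < K; ++k)
                    acc += static_cast<float>(a[g * batch_stride_a + m * stride_a + k]) *
                           static_cast<float>(b[g * batch_stride_b + k * stride_b + n]);
                c[g * batch_stride_c + m * stride_c + n] = static_cast<T>(acc);
            }
}
```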
@@ -4,7 +4,7 @@
#pragma once
template <typename ALayout, typename BLayout, typename CLayout>
-float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
+float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::DeviceMem& b_k_n_dev_buf,
                  ck_tile::DeviceMem& c_m_n_dev_buf,
                  ck_tile::index_t M,
@@ -13,7 +13,6 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::index_t stride_A,
                  ck_tile::index_t stride_B,
                  ck_tile::index_t stride_C,
-                  ck_tile::index_t kbatch,
                  ck_tile::index_t batch_stride_A,
                  ck_tile::index_t batch_stride_B,
                  ck_tile::index_t batch_stride_C,
@@ -21,11 +20,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  int n_warmup,
                  int n_repeat)
{
-    batched_gemm_basic_args args;
+    batched_gemm_args args;
    args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
    args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
    args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
-    args.kbatch = kbatch;
    args.M = M;
    args.N = N;
    args.K = K;
@@ -70,19 +68,11 @@ int run_batched_gemm_example(int argc, char* argv[])
    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
-    ck_tile::index_t batch_size = arg_parser.get_int("b");
    ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
    ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
    ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
    ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
-    std::cout << "Received args: " << std::endl;
-    std::cout << "batch_stride_A: " << batch_stride_A << '\n'
-              << "batch_stride_B: " << batch_stride_B << '\n'
-              << "batch_stride_C: " << batch_stride_C << '\n'
-              << "batch_count: " << batch_count << std::endl;
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");
@@ -92,16 +82,19 @@ int run_batched_gemm_example(int argc, char* argv[])
    using namespace ck_tile::literals;
-    auto f_host_tensor_descriptor =
-        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+    auto f_host_tensor_descriptor = [](std::size_t batch_count_,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       auto layout) {
        if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
        {
-            return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
                                                 {row * col, stride, 1_uz});
        }
        else
        {
-            return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+            return ck_tile::HostTensorDescriptor({batch_count_, row, col},
                                                 {row * col, 1_uz, stride});
        }
    };
@@ -130,10 +123,12 @@ int run_batched_gemm_example(int argc, char* argv[])
    stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
    stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
-    ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
-    ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
+    ck_tile::HostTensor<ADataType> a_m_k(
+        f_host_tensor_descriptor(batch_count, M, K, stride_A, ALayout{}));
+    ck_tile::HostTensor<BDataType> b_k_n(
+        f_host_tensor_descriptor(batch_count, K, N, stride_B, BLayout{}));
    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
-        f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+        f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));
    // TODO: add different init types
@@ -149,7 +144,7 @@ int run_batched_gemm_example(int argc, char* argv[])
    c_m_n_dev_buf.SetZero();
    c_m_n_dev_result.SetZero();
-    invoke_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
+    invoke_batched_gemm<ALayout, BLayout, CLayout>(a_m_k_dev_buf,
                                           b_k_n_dev_buf,
                                           c_m_n_dev_buf,
                                           M,
@@ -158,7 +153,6 @@ int run_batched_gemm_example(int argc, char* argv[])
                                           stride_A,
                                           stride_B,
                                           stride_C,
-                                           batch_size,
                                           batch_stride_A,
                                           batch_stride_B,
                                           batch_stride_C,
@@ -172,10 +166,10 @@ int run_batched_gemm_example(int argc, char* argv[])
    if(arg_parser.get_int("v") == 1)
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
-            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();
-        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
+        ck_tile::reference_batched_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref);
@@ -185,12 +179,12 @@ int run_batched_gemm_example(int argc, char* argv[])
    else if(arg_parser.get_int("v") == 2)
    {
        ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
-            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
+            f_host_tensor_descriptor(batch_count, M, N, stride_C, CLayout{}));
        ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
        c_m_n_gpu_ref.SetZero();
        c_m_n_gpu_buf_ref.SetZero();
-        ck_tile::reference_gemm_gpu<ADataType,
+        ck_tile::reference_batched_gemm_gpu<ADataType,
                                    BDataType,
                                    AccDataType,
                                    CDataType,
@@ -213,7 +207,7 @@ int run_batched_gemm_example(int argc, char* argv[])
        c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
        pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
-        std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
+        std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
    }
    return pass;
...
@@ -201,4 +201,117 @@ void reference_gemm_gpu(DeviceMem& a_device,
    return;
}
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
                                DeviceMem& b_device,
                                DeviceMem& c_device,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    ADataType* d_A;
    BDataType* d_B;
    CDataType* d_C;

    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
    if(errA != hipSuccess)
    {
        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
                  << std::endl;
        return; // Early exit on error
    }
    if(errB != hipSuccess)
    {
        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
                  << std::endl;
        return; // Early exit on error
    }
    if(errC != hipSuccess)
    {
        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
                  << std::endl;
        return; // Early exit on error
    }

    // Both the DeviceMem arguments and the staging buffers live in GPU memory,
    // so these copies are device-to-device.
    errA = hipMemcpy(d_A,
                     a_device.GetDeviceBuffer(),
                     batch_count * M * K * sizeof(ADataType),
                     hipMemcpyDeviceToDevice);
    if(errA != hipSuccess)
    {
        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipMemcpy(d_B,
                     b_device.GetDeviceBuffer(),
                     batch_count * N * K * sizeof(BDataType),
                     hipMemcpyDeviceToDevice);
    if(errB != hipSuccess)
    {
        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
    }

    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    // Launch the naive GEMM kernel once per batch, offsetting each pointer by its batch stride.
    for(int i = 0; i < batch_count; ++i)
    {
        ADataType* d_ATemp = d_A + i * batch_stride_A;
        BDataType* d_BTemp = d_B + i * batch_stride_B;
        CDataType* d_CTemp = d_C + i * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    // Copy the reference result back into the caller's output buffer (also device memory).
    errC = hipMemcpy(c_device.GetDeviceBuffer(),
                     d_C,
                     batch_count * M * N * sizeof(CDataType),
                     hipMemcpyDeviceToDevice);
    if(errC != hipSuccess)
    {
        std::cerr << "Error copying C back to the output buffer: " << hipGetErrorString(errC)
                  << std::endl;
    }

    errA = hipFree(d_A);
    if(errA != hipSuccess)
    {
        std::cerr << "Error freeing the A memory: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipFree(d_B);
    if(errB != hipSuccess)
    {
        std::cerr << "Error freeing the B memory: " << hipGetErrorString(errB) << std::endl;
    }
    errC = hipFree(d_C);
    if(errC != hipSuccess)
    {
        std::cerr << "Error freeing the C memory: " << hipGetErrorString(errC) << std::endl;
    }
    return;
}
} // namespace ck_tile
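The reference above stages A, B and C into temporary buffers and launches `naive_gemm_kernel` once per batch from the host. An alternative formulation, shown here only as a sketch and not as code from this commit, folds the batch dimension into the grid's z dimension and offsets the base pointers inside the kernel; this mirrors the per-batch pointer arithmetic (`i_k * kargs.batch_stride_B`) visible in the `BatchedGemmKernel` hunk below, assuming the batch index is derived from a grid dimension.

```
#include <hip/hip_runtime.h>

// Sketch only: fp32 naive batched GEMM, batch index taken from blockIdx.z.
// Row-major A (M x K), B (K x N) and C (M x N) assumed for brevity.
__global__ void naive_batched_gemm_f32(const float* a, const float* b, float* c,
                                       int M, int N, int K,
                                       int stride_a, int stride_b, int stride_c,
                                       long batch_stride_a, long batch_stride_b,
                                       long batch_stride_c)
{
    const int batch  = blockIdx.z;                    // one grid z-slice per batch
    const float* a_g = a + batch * batch_stride_a;
    const float* b_g = b + batch * batch_stride_b;
    float* c_g       = c + batch * batch_stride_c;

    const int idx = blockIdx.x * blockDim.x + threadIdx.x; // one thread per C element
    if(idx >= M * N)
        return;
    const int m = idx / N;
    const int n = idx % N;

    float acc = 0.f;
    for(int k = 0; k < K; ++k)
        acc += a_g[m * stride_a + k] * b_g[k * stride_b + n];
    c_g[m * stride_c + n] = acc;
}

// Illustrative launch: x covers the M*N output elements, z covers the batches.
// dim3 grid((M * N + 255) / 256, 1, batch_count);
// naive_batched_gemm_f32<<<grid, 256>>>(d_a, d_b, d_c, M, N, K,
//                                       K, N, N, long(M) * K, long(K) * N, long(M) * N);
```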
@@ -96,13 +96,6 @@ struct BatchedGemmKernel
        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) +
                                   __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
        // Convert pointers to tensor views
-        // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
-        // {
-        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A): %d\n",
-        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A));
-        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B): %d\n",
-        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B));
-        // }
        auto a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
...