"model/vscode:/vscode.git/clone" did not exist on "4dcf80167aca16c90bd2d01e0b91473e595ae936"
Unverified commit 1e73adbc authored by Bartłomiej Kocot, committed by GitHub

Add optimized blockwise gemm using ck wrapper (#1157)



* Add optimized blockwise gemm using ck wrapper

* Add basic gemm example

* Update docs

* Add tutorial for gemm using ck wrapper

* Add perf note

* edits

* Fix cmake

* Fixes

---------
Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
parent bf98b476
...@@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations)
endif()
# Composable Kernel wrapper GEMM tutorial
This tutorial demonstrates how to implement matrix multiplication using the Composable Kernel (CK)
wrapper. We present a base version of GEMM without most of the available optimizations; note,
however, that CK provides kernels with many different optimizations.
To implement these optimizations, you can use the CK wrapper or directly use the instances
available in CK. You can also refer to the
[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp),
which uses the CK wrapper and is based on the
[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation.
The kernel definition should look similar to:
```cpp
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
```
We pass pointers to global memory and the matrix dimensions as kernel arguments. We also pass the
tile sizes processed by each block (`tile_shape`) and the thread layout (`thread_layout`). As
compile-time parameters, we define the data type, the
[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp),
and the scalar-per-vector value used during copies.
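For reference, the basic GEMM example added alongside this tutorial selects these parameters on the
host as shown below; `PerformGemm` is the example's host helper, which allocates the device buffers,
computes the grid size, and launches `DeviceGemm`:
```cpp
// Host-side configuration from the basic GEMM example: half-precision data,
// a 256x128x32 tile per block, and a 64x4 thread layout.
using DataType = ck::half_t;
const auto thread_layout =
    ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
                             ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
// The GEMM traits and the scalar-per-vector value (8 here) are passed as template arguments.
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
    3840, 4096, 4096, tile_shape, thread_layout);
```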
Step 1: Create layouts for global and LDS memory.
```cpp
// Specify layouts for global memory.
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Specify layouts for tiles.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
// Apply padding for global memory.
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
```
We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or
`KPerBlock`.
Step 2: Create tensors for global and LDS memory.
```cpp
// Make tensors for global memory.
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_global_layout_padded);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_global_layout_padded);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_global_layout_padded);
// Allocate LDS memory.
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
// Make tensors for lds memory.
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
```
We specify the access parameters for the copy and express the block index as a tuple:
```cpp
// Specify block index as tuple.
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
// Specify access parameters for copy.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
```
We create a local tile (per block) and a local partition (per thread) for the `C` tensor in global
memory. We also define and clear an output register tensor (`c_vgpr_reg`) used for accumulation.
```cpp
auto c_global_local_tile = ck::wrapper::make_local_tile(
c_global_tensor,
tile_shape,
block_idxs,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
// Create C vgpr to accumulate results.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
// Clear C vgpr.
ck::wrapper::clear(c_vgpr_reg);
```
We use two functions specific to `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and
`make_blockwise_gemm_xdl_c_vgpr`. They select the appropriate partition for the `C` output and
define tensors with the layouts that `blockwise_gemm` expects. In the following step, we use only
generic CK wrapper functions.
Step 3: Create the compute loop.
```cpp
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0;
do
{
// Get KPerBlock slice.
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
// Create local tiles for A and B.
auto a_global_local_tile = ck::wrapper::make_local_tile(
a_global_tensor_k_slice,
tile_shape,
block_idxs,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
auto b_global_local_tile = ck::wrapper::make_local_tile(
b_global_tensor_k_slice,
tile_shape,
block_idxs,
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
// Copy from global to LDS.
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_tile, a_lds_tensor, thread_layout);
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_tile, b_lds_tensor, thread_layout);
// Synchronize lds.
ck::block_sync_lds();
// Execute blockwise GEMM.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
++i;
} while(i < num_loop);
```
The loop runs `K / KPerBlock` times (rounded up). In each iteration, local tiles are created for the
A and B tensors (one tile per block) and data is copied from global memory to LDS. The
`blockwise_gemm_xdl` function then performs the GEMM operation on `a_lds_tensor` and `b_lds_tensor`
and accumulates the results in `c_vgpr_reg`.
Finally, the result accumulated in `c_vgpr_reg` is stored to the `C` local partition (one tensor per thread):
```cpp
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
```
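The host launches the kernel with a two-dimensional grid, one block per (`MPerBlock`, `NPerBlock`)
output tile. Below is a condensed sketch of the host-side launch used in the example (buffer
allocation and performance reporting omitted):
```cpp
// Condensed host-side launch from the example: one block per output tile.
const ck::index_t grid_size_x =
    ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
    ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
    DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;
// launch_and_time_kernel (from ck/host_utility/kernel_launch.hpp) returns the measured
// kernel time in milliseconds, which the example uses to report TFlops and GB/s.
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                              kernel,
                                              dim3(grid_size_x, grid_size_y, 1),
                                              dim3(ck::wrapper::size(thread_layout)),
                                              0,
                                              a_mem.GetDeviceBuffer(),
                                              b_mem.GetDeviceBuffer(),
                                              c_mem.GetDeviceBuffer(),
                                              M,
                                              N,
                                              K,
                                              tile_shape,
                                              thread_layout);
```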
If you want to dive deeper into the details, you can find the complete example
[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp).
...@@ -6,13 +6,9 @@ ...@@ -6,13 +6,9 @@
#include <iostream> #include <iostream>
#include <initializer_list> #include <initializer_list>
#include <vector> #include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/device_memory.hpp"
...@@ -23,94 +19,88 @@ ...@@ -23,94 +19,88 @@
#include "ck/wrapper/tensor.hpp" #include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp" #include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType> struct SimpleDeviceMem
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{ {
using PassThrough = ck::tensor_operation::element_wise::PassThrough; SimpleDeviceMem() = delete;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K})); SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K})); {
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N})); (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
a_m_k.mData = a_data; void* GetDeviceBuffer() { return p_mem_; }
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{}; ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument); void* p_mem_;
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); };
}
template <typename DataType, template <typename DataType,
typename GemmTraits, typename GemmTraits,
ck::index_t scalar_per_vector, ck::index_t scalar_per_vector,
typename BlockShape, typename BlockShape,
typename ThreadLayoutShape> typename ThreadLayout>
__global__ void DeviceGemm(const void* p_a, __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
const ck::index_t M, const ck::index_t M,
const ck::index_t N, const ck::index_t N,
const ck::index_t K, const ck::index_t K,
const BlockShape tile_shape, const BlockShape tile_shape,
const ThreadLayoutShape thread_layout) const ThreadLayout thread_layout)
{ {
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
// Specify layouts for global memory.
const auto a_global_layout = const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout = const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout = const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Specify layouts for tiles.
constexpr auto a_tile_layout = ck::wrapper::make_layout( constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout( constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout( constexpr auto c_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
// Apply padding for global memory.
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
// Make tensors for global memory.
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>( auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_global_layout); static_cast<const DataType*>(p_a), a_global_layout_padded);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>( auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_global_layout); static_cast<const DataType*>(p_b), b_global_layout_padded);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>( auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_global_layout); static_cast<DataType*>(p_c), c_global_layout_padded);
// Allocate lds memory.
auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout));
auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout));
auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
// Make tensors for lds memory.
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>( auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout); static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>( auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout); static_cast<DataType*>(lds_b), b_tile_layout);
// Specify block index as tuple.
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x); const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
// Specify access parameters for copy.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>; using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1; constexpr ck::index_t vector_dim = 1;
// Create tile and partition for C. Use specific function for blockwise_gemm to assign the
// appropriate partitions.
auto c_global_local_tile = ck::wrapper::make_local_tile( auto c_global_local_tile = ck::wrapper::make_local_tile(
c_padded_global_tensor, c_global_tensor,
tile_shape, tile_shape,
block_idx, block_idxs,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition = auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType, ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
...@@ -118,42 +108,49 @@ __global__ void DeviceGemm(const void* p_a, ...@@ -118,42 +108,49 @@ __global__ void DeviceGemm(const void* p_a,
decltype(b_tile_layout), decltype(b_tile_layout),
ck::wrapper::size(thread_layout), ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile); GemmTraits>(c_global_local_tile);
// Create C vgpr to accumulate results.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType, auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout), decltype(a_tile_layout),
decltype(b_tile_layout), decltype(b_tile_layout),
ck::wrapper::size(thread_layout), ck::wrapper::size(thread_layout),
GemmTraits>(); GemmTraits>();
// Clear C vgpr.
ck::wrapper::clear(c_vgpr_reg); ck::wrapper::clear(c_vgpr_reg);
// Iterate over K with KPerBlock step.
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0; ck::index_t i = 0;
do do
{ {
// Get KPerBlock slice.
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice); auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice); auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
// Create local tiles for A and B.
auto a_global_local_tile = ck::wrapper::make_local_tile( auto a_global_local_tile = ck::wrapper::make_local_tile(
a_padded_global_tensor_k_slice, a_global_tensor_k_slice,
tile_shape, tile_shape,
block_idx, block_idxs,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
auto b_global_local_tile = ck::wrapper::make_local_tile( auto b_global_local_tile = ck::wrapper::make_local_tile(
b_padded_global_tensor_k_slice, b_global_tensor_k_slice,
tile_shape, tile_shape,
block_idx, block_idxs,
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
// Copy from global to lds.
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>( ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_tile, a_lds_tensor, thread_layout); a_global_local_tile, a_lds_tensor, thread_layout);
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>( ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_tile, b_lds_tensor, thread_layout); b_global_local_tile, b_lds_tensor, thread_layout);
// Synchronize lds.
ck::block_sync_lds(); ck::block_sync_lds();
// Execute blockwise gemm.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>( ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg); a_lds_tensor, b_lds_tensor, c_vgpr_reg);
++i; ++i;
} while(i < num_loop); } while(i < num_loop);
// Copy vgpr results to C global memory.
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
} }
...@@ -161,36 +158,28 @@ template <typename DataType, ...@@ -161,36 +158,28 @@ template <typename DataType,
typename GemmTraits, typename GemmTraits,
ck::index_t scalar_per_vector, ck::index_t scalar_per_vector,
typename BlockShape, typename BlockShape,
typename ThreadLayoutShape> typename ThreadLayout>
void PerformGemm(const ck::index_t M, void PerformGemm(const ck::index_t M,
const ck::index_t N, const ck::index_t N,
const ck::index_t K, const ck::index_t K,
const BlockShape& tile_shape, const BlockShape& tile_shape,
const ThreadLayoutShape& thread_layout) const ThreadLayout& thread_layout)
{ {
// Global memory buffers // Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType)); SimpleDeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType)); SimpleDeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType)); SimpleDeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data()); const ck::index_t grid_size_x =
b_mem.ToDevice(b_data.data()); ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
c_mem.SetZero(); const ck::index_t grid_size_y =
const ck::index_t grid_size =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) *
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel = const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayoutShape>; DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;
launch_and_time_kernel(StreamConfig{nullptr}, const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel, kernel,
dim3(grid_size), dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)), dim3(ck::wrapper::size(thread_layout)),
0, 0,
a_mem.GetDeviceBuffer(), a_mem.GetDeviceBuffer(),
...@@ -202,56 +191,26 @@ void PerformGemm(const ck::index_t M, ...@@ -202,56 +191,26 @@ void PerformGemm(const ck::index_t M,
tile_shape, tile_shape,
thread_layout); thread_layout);
std::vector<DataType> c_data(M * N); std::size_t flop = std::size_t(2) * M * N * K;
c_mem.FromDevice(c_data.data()); std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float) float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
{ float gb_per_sec = num_btype / 1.E6 / avg_time;
using DataType = float;
const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8) std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
{ << gb_per_sec << " GB/s, " << std::endl;
using DataType = int8_t;
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
129, 129, 67, tile_shape, thread_layout);
} }
TEST(TestGemm, Half) int main(int argc, char* argv[])
{ {
using DataType = ck::half_t; using DataType = ck::half_t;
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); const auto thread_layout =
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>( ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
512, 512, 128, tile_shape, thread_layout); const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
// Irregular case PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>( 3840, 4096, 4096, tile_shape, thread_layout);
129, 129, 67, tile_shape, thread_layout); return 0;
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{});
const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave);
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave);
} }
// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s,
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "ck/wrapper/layout.hpp" #include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp" #include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp" #include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t NumDimSpatial = 3;
using DataType = float; using DataType = float;
...@@ -36,21 +37,20 @@ struct SimpleDeviceMem ...@@ -36,21 +37,20 @@ struct SimpleDeviceMem
void* p_mem_; void* p_mem_;
}; };
// Test copy from Global to Global through LDS and VGPR template <typename InputTensor, typename OutputTensor, typename BlockShape, typename ThreadLayout>
template <typename InputTensor, __global__ void __CK_WRAPPER_LAUNCH_BOUNDS__
typename OutputTensor, DeviceImageToColumnPad0(InputTensor input_tensor,
typename BlockShape,
typename ThreadLayoutShape>
__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
OutputTensor output_tensor, OutputTensor output_tensor,
const BlockShape tile_shape, const BlockShape tile_shape,
const ThreadLayoutShape thread_layout) const ThreadLayout thread_layout)
{ {
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x); // grid layout (dim1, dim0)
const auto block_idxs =
ck::make_tuple(static_cast<ck::index_t>(blockIdx.y), static_cast<ck::index_t>(blockIdx.x));
// Get local tiles for global memory // Get local tiles for global memory
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
// Get partition per thread // Get partition per thread
const auto input_local_partition = const auto input_local_partition =
...@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G, ...@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));
// User can choose appropriate number of threads and sizes per block // User can choose appropriate number of threads and sizes per block
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
// This example doesn't support padding, user should select tile sizes // This example doesn't support padding, user should select tile sizes
// which divides the shape completely // which are divisible by the shape.
const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});
// Create buffers for global memory // Create buffers for global memory
...@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G, ...@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>( auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout); static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);
const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), // grid layout (dim1, dim0)
ck::wrapper::size<0>(tile_shape)) * const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
ck::wrapper::size<1>(tile_shape)); ck::wrapper::size<1>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
ck::wrapper::size<0>(tile_shape));
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global), const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
decltype(output_tensor_global), decltype(output_tensor_global),
...@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G, ...@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G,
decltype(thread_layout)>; decltype(thread_layout)>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel, kernel,
dim3(grid_size), dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)), dim3(ck::wrapper::size(thread_layout)),
0, 0,
input_tensor_global, input_tensor_global,
...@@ -178,3 +181,4 @@ int main(int argc, char* argv[]) ...@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
{1, 1, 1} /*filter_dilations*/); {1, 1, 1} /*filter_dilations*/);
return 0; return 0;
} }
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
// Create layouts for global memory
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Apply padding
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
// Create tensors for global memory
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Create layouts and tensors for lds memory.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
// Create tile and partition for C global memory. Use specific gemm
// functions to get appropriate layouts.
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
// Define and clear c vgpr register
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
// Local partitions for lds memory
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
// Lamda to slice tensor, then create local tile and partition
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
// Copy first values to lds
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
// Pipeline loop
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
// Skip if only tile should be processed
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
// Copy data to A vgpr.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
// Synchronize.
ck::block_sync_lds();
// Copy data to B vgpr.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
// Perform gemm.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
// Synchronize
ck::block_sync_lds();
// Copy data to A and B lds tiles.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
// Handle tail.
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
// Store data from C vgpr to C global memory.
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
SimpleDeviceMem a_mem(M * K * sizeof(DataType));
SimpleDeviceMem b_mem(K * N * sizeof(DataType));
SimpleDeviceMem c_mem(M * N * sizeof(DataType));
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8, false>(
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s,
...@@ -12,10 +12,6 @@ Wrapper
Description
-------------------------------------
.. note::
The wrapper is under development and its functionality is limited.
The CK library provides a lightweight wrapper for more complex operations implemented in
the library.
...@@ -54,9 +50,15 @@ Output::
2 6 10 14 18 22 26 30
Tutorials:
* `GEMM tutorial <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/README.md>`_
Advanced examples:
* `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
* `Basic gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp>`_
* `Optimized gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp>`_
-------------------------------------
Layout
......
...@@ -61,8 +61,8 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) ...@@ -61,8 +61,8 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype(dim_access_order), decltype(dim_access_order),
VectorDim, VectorDim,
ScalarPerVector, ScalarPerVector,
Sequence<false>, Sequence<true>,
Sequence<false>>{in_grid_desc, Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()), make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc, out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()), make_tuple(dst_tensor.GetMultiIdxOffsets()),
...@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) ...@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{ {
// Perform copy from DynamicBuffer to StaticBuffer // Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin = const auto dst_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{}); generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2( auto transfer = ThreadwiseTensorSliceTransfer_v2<
[&](auto I) { std::remove_const_t<typename SrcTensorType::TensorElementType>,
if constexpr(I == VectorDim) std::remove_const_t<typename DstTensorType::TensorElementType>,
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>, remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>, remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths), decltype(thread_slice_lengths),
decltype(dim_access_order), decltype(dim_access_order),
decltype(src_vector_tensor_lengths), VectorDim,
decltype(dim_access_order)>{ ScalarPerVector,
src_tensor.GetMultiIdxOffsets()}; I1,
false,
false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc, transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(), src_tensor.GetBuffer(),
out_grid_desc, out_grid_desc,
src_dst_slice_origin, dst_slice_origin_idxs,
dst_tensor.GetBuffer()); dst_tensor.GetBuffer());
} }
else else
...@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple, ...@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
index_t ScalarPerVector, index_t ScalarPerVector,
typename SrcTensorType, typename SrcTensorType,
typename DstTensorType, typename DstTensorType,
typename ThreadLayoutTuple> typename ThreadShape,
__device__ void blockwise_copy(const SrcTensorType& src_tensor, typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor, DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout) [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{ {
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer); static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value); static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
...@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor, ...@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr auto tile_lengths_seq = constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{}); generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2( constexpr auto thread_layout_seq =
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{}); generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2( constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{}); [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>; using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
// Perform copy between DynamicBuffers // Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7< auto transfer = ThreadGroupTensorSliceTransfer_v7<
......
...@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor() ...@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()
/** /**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* data layout must be (NPerBlock, KPerBlock). * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
* *
* \note C output Vgpr register layout (8D): * \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M * - MXdlPerWave - The number of MFMA instructions run by single wave in M
...@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor() ...@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Tensor to pad. * \tparam BlockSize Tensor to pad.
* \tparam GemmTraits Traits of gemm xdl operation. * \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout. * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout. * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm. * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/ */
template <typename DataType, template <typename DataType,
...@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, ...@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor, const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor) CTensorType& c_reg_tensor)
{ {
constexpr auto I3 = Number<3>{};
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr); static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
...@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, ...@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>; using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>; using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type = using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>()); conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type = using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>()); conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType, DataType,
...@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor) ...@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr auto I6 = Number<6>{}; constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{}; constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer = constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>; is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
    const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
        layout(c_local_tile_tensor).GetUnrolledDescriptor());

    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});

    auto sliced_desc = transform_tensor_descriptor(
        partition_desc,
        make_tuple(
            make_slice_transform(partition_shape.At(Number<0>{}),
                                 m_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<1>{}),
                                 n_thread_data_on_grid_idx[I0],
                                 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
            make_slice_transform(partition_shape.At(Number<2>{}),
                                 m_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<3>{}),
                                 n_thread_data_on_grid_idx[I1],
                                 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
            make_slice_transform(partition_shape.At(Number<4>{}),
                                 m_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
            make_slice_transform(partition_shape.At(Number<5>{}),
                                 m_thread_data_on_grid_idx[I3],
                                 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
            make_slice_transform(partition_shape.At(Number<6>{}),
                                 m_thread_data_on_grid_idx[I4],
                                 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
            make_slice_transform(partition_shape.At(Number<7>{}),
                                 n_thread_data_on_grid_idx[I2],
                                 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
        lower_upper_dims,
        lower_upper_dims);

    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
        c_local_tile_tensor.GetPointer(), partition_layout);
    return partition_tensor;
}
@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
                  typename BTileLayout::LayoutShape{}.Size());

    constexpr bool is_integer =
        is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
    using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
    using ABlockDesc_K0_M_K1_Type =
        conditional_t<is_3d_desc,
                      typename ATileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
    using BBlockDesc_K0_N_K1_Type =
        conditional_t<is_3d_desc,
                      typename BTileLayout::LayoutUnrolledDescriptorType,
                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
    using BlockwiseGemmXdlops =
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
    const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
        vgpr_shape, vgpr_desc);
    // Get vector type for Vgpr
    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
    return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
        vgpr_layout);
}
...
@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
    }
}

template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                            const Shape& shape,
                                                            const UnrolledDescriptor& flatten_desc)
{
    constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
...
@@ -20,48 +20,57 @@ namespace wrapper {
 * \tparam K1Value The number of K-dim elements that are packed together as
 * a separate logical dimension. Usually aligns with vector load size.
 */
template <typename MPerXDLValue,
          typename NPerXDLValue,
          typename MXdlPerWaveValue,
          typename NXdlPerWaveValue,
          typename K1Value>
struct BlockwisGemmXdlTraits
{
    static constexpr auto MPerXDL     = MPerXDLValue{};
    static constexpr auto NPerXDL     = NPerXDLValue{};
    static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
    static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
    static constexpr auto K1          = K1Value{};
};

// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{
};
...
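For orientation, these predefined trait bundles are what the wrapper GEMM helpers take as their `GemmTraits` template parameter. A minimal illustrative sketch (not part of the diff, assuming a HIP build with the headers shown here) of aliasing one and reading its compile-time fields:

```cpp
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp"

// Illustrative only: each field is a ck::Number<> constant that converts to ck::index_t.
using Traits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1;

inline constexpr ck::index_t traits_m_per_xdl      = Traits::MPerXDL;     // 32
inline constexpr ck::index_t traits_n_per_xdl      = Traits::NPerXDL;     // 32
inline constexpr ck::index_t traits_m_xdl_per_wave = Traits::MXdlPerWave; // 2
inline constexpr ck::index_t traits_n_xdl_per_wave = Traits::NXdlPerWave; // 2
inline constexpr ck::index_t traits_k1             = Traits::K1;          // 8
```

The GEMM test added later in this commit passes these trait types straight through to `blockwise_gemm_xdl`, `make_blockwise_gemm_xdl_c_local_partition`, and `make_blockwise_gemm_xdl_c_vgpr`.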
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
} // namespace wrapper
} // namespace ck
@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace wrapper {
@@ -29,6 +30,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
namespace detail {
/**
 * \brief Generate packed (column-major) strides if not passed
 *
@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
    return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace
/// @endcond
@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, strides));
}

/**
@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}
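Since the strides-less overload falls back to the packed column-major strides described above, the two calls in this small illustrative sketch (not from the CK sources) should build equivalent layouts:

```cpp
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"

// Illustrative only: for a (6, 4) shape the packed column-major default gives
// strides (1, 6), so both layouts map (row, col) to row + 6 * col.
inline void make_layout_default_strides_example()
{
    const auto shape = ck::make_tuple(ck::Number<6>{}, ck::Number<4>{});

    const auto packed_layout   = ck::wrapper::make_layout(shape);
    const auto explicit_layout = ck::wrapper::make_layout(
        shape, ck::make_tuple(ck::Number<1>{}, ck::Number<6>{}));

    (void)packed_layout;
    (void)explicit_layout;
}
```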
// Layout helpers
// get
/**
 * \private
 * \brief Get dim.
@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
 * \param layout Layout to create sub layout.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
    const auto& shape = layout.GetShape();
    const auto new_shape = get<idx>(shape);
@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
    return layout.GetShape();
}
// pad
/**
* \brief Pad layout shapes to be adjusted to tile lengths.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
const TileLengths& tile_lengths)
{
auto& unrolled_desc = layout.GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
// Create layout
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
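A host-side sketch of `pad` on its own (the 129x67 shape and 64x64 tile lengths are made up for illustration; the GEMM test below reaches this function through its `ApplyPadding` helper):

```cpp
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"

// Illustrative only: pad an irregular (M, N) = (129, 67) row-major layout so both
// dimensions become multiples of a 64x64 tile.
inline void pad_example()
{
    const auto layout = ck::wrapper::make_layout(ck::make_tuple(129, 67),
                                                 ck::make_tuple(67, 1));
    const auto tile_lengths  = ck::make_tuple(ck::Number<64>{}, ck::Number<64>{});
    const auto padded_layout = ck::wrapper::pad(layout, tile_lengths);
    // Each shape entry is rounded up to a multiple of the matching tile length,
    // so the padded shape is expected to be (192, 128).
    (void)padded_layout;
}
```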
// unmerge
/**
* \brief Unmerge selected dim in layout.
*
* \tparam Idx Index to dimension being unmerged.
 * \param layout Layout to unmerge.
* \param new_lengths Dimensions into which the indicated dimension will be divided.
* \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
* \return Unmerged layout.
*/
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
const NewLengths& new_lengths,
[[maybe_unused]] const NewIdxs& new_indexes)
{
const auto& layout_shape = shape(layout);
auto& unrolled_desc = layout.GetUnrolledDescriptor();
constexpr auto dims = Shape::Size();
// Generate transforms
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == Idx)
{
return make_unmerge_transform(new_lengths);
}
else
{
return make_pass_through_transform(layout_shape.At(i));
}
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
{
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
return to_sequence(idxs_tuple);
}
else
{
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
return Sequence<index>{};
}
},
Number<dims>{});
const auto unmerged_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
const auto unmerged_shape =
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
// Create layout
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
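And a matching sketch for `unmerge` (again illustrative sizes, not from the sources), reshaping a row-major (M, K) layout into (K0, M, K1) the same way the GEMM test below reshapes its A and B tensors:

```cpp
#include "ck/utility/common_header.hpp"
#include "ck/wrapper/layout.hpp"

// Illustrative only: split dim 1 (K) into (K0, K1) with K1 = 8 and reorder to (K0, M, K1).
inline void unmerge_example()
{
    const ck::index_t M = 256;
    const ck::index_t K = 64;
    const auto a_layout = ck::wrapper::make_layout(ck::make_tuple(M, K),
                                                   ck::make_tuple(K, 1));
    const auto k1         = ck::Number<8>{};
    const ck::index_t k0  = K / 8;
    // The nested (0, 2) entry places K0 at position 0 and K1 at position 2,
    // while the plain 1 moves M to position 1.
    const auto new_idxs =
        ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
    const auto a_k0_m_k1_layout =
        ck::wrapper::unmerge<1>(a_layout, ck::make_tuple(k0, k1), new_idxs);
    (void)a_k0_m_k1_layout;
}
```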
} // namespace wrapper
} // namespace ck
add_custom_target(test_wrapper)

add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
   GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
   GPU_TARGETS MATCHES "gfx942")
    add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
    target_link_libraries(test_wrapper_gemm PRIVATE utility)
    add_dependencies(test_wrapper test_wrapper_gemm)
endif()
@@ -20,23 +20,25 @@
template <typename InputTensor,
          typename OutputTensor,
          typename BlockShape,
          typename ThreadLayout,
          bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
                               OutputTensor output_tensor,
                               const BlockShape tile_shape,
                               const ThreadLayout thread_layout)
{
    __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
    const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        p_shared, ck::wrapper::make_layout(tile_shape));

    const auto block_idxs =
        ck::make_tuple(static_cast<ck::index_t>(blockIdx.x), static_cast<ck::index_t>(blockIdx.y));
    // Get local tiles for global memory
    const auto input_local_tile =
        ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
    const auto output_local_tile =
        ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
    // Get partition per thread
    const auto input_local_partition =
@@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor,
    // Allocate VGPR
    auto tensor_vgpr =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
            ck::wrapper::make_layout(shape(lds_local_partition)));
    // Perform copy
    if constexpr(UseOptimizedCopy)
@@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS()
    auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);

    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});

    const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
        ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
    const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
        ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));

    const auto kernel = TestCopyDevice<decltype(input_tensor_global),
                                       decltype(output_tensor_global),
@@ -112,7 +117,7 @@ void PerformCopyGlobalToGlobalViaLDS()
                                       UseOptimizedCopy>;
    launch_and_time_kernel(StreamConfig{},
                           kernel,
                           dim3(grid_size_x, grid_size_y, 1),
                           dim3(ck::wrapper::size(thread_layout)),
                           0,
                           input_tensor_global,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
a_m_k.mData = a_data;
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Add extra M and N
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
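// Stage the first K-tile: copy global -> VGPR below, then VGPR -> LDS, so the main
// loop always computes on data that is already resident in LDS.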
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
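// Software pipeline: prefetch K-tile i+1 from global memory into VGPRs, run the
// blockwise GEMM on the tile currently staged in LDS, then write the prefetched
// tile into LDS for the next iteration.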
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
// (dim1, dim2, dim0 thread layout)
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType,
ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1,
16,
false>(512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iostream>
...
@@ -30,7 +30,10 @@ TEST(TestPartition, LocalPartition)
        ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);

    const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
    // row-major thread layout
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{}));
    // 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
    const auto thread_projection =
        ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
@@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile)
        ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
    const auto block_projection =
        ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));

    const auto grid_shape =
        ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
                       ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
                       ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));

    std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t, ck::index_t>> block_idxs;
    for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++)
    {
        for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++)
        {
            for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++)
            {
                block_idxs.emplace_back(i, j, k, 0);
            }
        }
    }

    for(auto block_idx : block_idxs)
    {
        constexpr ck::index_t projection_block_dim = ck::Number<2>{};
        const auto packed_tile =
            ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
        const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
        auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) *
                                       ck::wrapper::size<2>(block_shape) *
                                       ck::wrapper::size<2>(strides);
        expected_tile_first_val += ck::wrapper::size<1>(block_idx) *
                                   ck::wrapper::size<1>(block_shape) *
                                   ck::wrapper::size<1>(strides);
        expected_tile_first_val += ck::wrapper::size<0>(block_idx) *
                                   ck::wrapper::size<0>(block_shape) *
                                   ck::wrapper::size<0>(strides);
...