Unverified Commit 1e73adbc authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Add optimized blockwise gemm using ck wrapper (#1157)



* Add optimized blockwise gemm using ck wrapper

* Add basic gemm example

* Update docs

* Add tutorial for gemm using ck wrapper

* Add perf note

* edits

* Fix cmake

* Fixes

---------
Co-authored-by: default avatarLisa Delaney <lisa.delaney@amd.com>
parent bf98b476
......@@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations)
endif()
# Composable Kernel wrapper GEMM tutorial
This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK)
wrapper. We present the base version of GEMM without most of the available optimizations; however,
it's worth noting that CK has kernels with different optimizations.
To implement these optimizations, you can use the CK wrapper or directly use available instances in
CK. You can also refer to the
[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp),
that uses CK wrapper based on the
[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation.
The kernel definition should look similar to:
```cpp
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
```
We pass pointers to global memory and matrix dimensions via arguments. Additionally, we pass
selected lengths of processed data through each block (`tile_shape`) and thread layout
(`thread_layout`). For compilation time parameters, we define the data type,
[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp)
and scalar per vector value during copy.
Step 1: Create layouts for global and LDS memory.
```cpp
// Specify layouts for global memory.
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Specify layouts for tiles.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
// Apply padding for global memory.
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
```
We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or
`KPerBlock`.
Step 2: Create tensors for global and LDS memory.
```cpp
// Make tensors for global memory.
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_global_layout_padded);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_global_layout_padded);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_global_layout_padded);
// Allocate LDS memory.
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
// Make tensors for lds memory.
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
```
We must specify parameters for copy and convert block indexes to tuple:
```cpp
// Specify block index as tuple.
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
// Specify access parameters for copy.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
```
We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also
define and clear an output register (`c_vgpr_reg`) for the accumulation.
```cpp
auto c_global_local_tile = ck::wrapper::make_local_tile(
c_global_tensor,
tile_shape,
block_idxs,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
// Create C vgpr to accumulate results.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
// Clear C vgpr.
ck::wrapper::clear(c_vgpr_reg);
```
We use two specific functions for `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and
`make_blockwise_gemm_xdl_c_vgpr`. This helps to choose the appropriate partition for the `C` output
and define tensors with specific layouts for `blockwise_gemm`. In the following step, we use only
generic functions for the CK wrapper.
Step 3: Create the compute loop.
```cpp
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0;
do
{
// Get KPerBlock slice.
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
// Create local tiles for A and B.
auto a_global_local_tile = ck::wrapper::make_local_tile(
a_global_tensor_k_slice,
tile_shape,
block_idxs,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
auto b_global_local_tile = ck::wrapper::make_local_tile(
b_global_tensor_k_slice,
tile_shape,
block_idxs,
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
// Copy from global to LDS.
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_tile, a_lds_tensor, thread_layout);
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_tile, b_lds_tensor, thread_layout);
// Synchronize lds.
ck::block_sync_lds();
// Execute blockwise GEMM.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
++i;
} while(i < num_loop);
```
Loop iterate over `K / KPerBlock`. Each time a local tile is created for A and B tensors (tensor per block),
data is copied from global memory to LDS. The `blockwise_gemm` function performs the GEMM
operation on `a_lds_tensor` and `b_lds_tensor`, and stores results in `c_vgpr_reg`.
The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread):
```cpp
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
```
If you want to dive deep into the details, you can find the entire example
[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp).
......@@ -6,13 +6,9 @@
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
......@@ -23,94 +19,88 @@
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
struct SimpleDeviceMem
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
SimpleDeviceMem() = delete;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
a_m_k.mData = a_data;
b_k_n.mData = b_data;
void* GetDeviceBuffer() { return p_mem_; }
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
void* p_mem_;
};
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayoutShape>
__global__ void DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
// Specify layouts for global memory.
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Specify layouts for tiles.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
// Apply padding for global memory.
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
// Make tensors for global memory.
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_global_layout);
static_cast<const DataType*>(p_a), a_global_layout_padded);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_global_layout);
static_cast<const DataType*>(p_b), b_global_layout_padded);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_global_layout);
auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout));
auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout));
auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout));
static_cast<DataType*>(p_c), c_global_layout_padded);
// Allocate lds memory.
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
// Make tensors for lds memory.
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
// Specify block index as tuple.
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
// Specify access parameters for copy.
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
// Create tile and partition for C. Use specific function for blockwise_gemm to assign the
// appropriate partitions.
auto c_global_local_tile = ck::wrapper::make_local_tile(
c_padded_global_tensor,
c_global_tensor,
tile_shape,
block_idx,
block_idxs,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
......@@ -118,42 +108,49 @@ __global__ void DeviceGemm(const void* p_a,
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
// Create C vgpr to accumulate results.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
// Clear C vgpr.
ck::wrapper::clear(c_vgpr_reg);
// Iterate over K with KPerBlock step.
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0;
do
{
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice);
auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice);
auto a_global_local_tile = ck::wrapper::make_local_tile(
a_padded_global_tensor_k_slice,
// Get KPerBlock slice.
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
// Create local tiles for A and B.
auto a_global_local_tile = ck::wrapper::make_local_tile(
a_global_tensor_k_slice,
tile_shape,
block_idx,
block_idxs,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
auto b_global_local_tile = ck::wrapper::make_local_tile(
b_padded_global_tensor_k_slice,
b_global_tensor_k_slice,
tile_shape,
block_idx,
block_idxs,
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
// Copy from global to lds.
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_tile, a_lds_tensor, thread_layout);
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_tile, b_lds_tensor, thread_layout);
// Synchronize lds.
ck::block_sync_lds();
// Execute blockwise gemm.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
++i;
} while(i < num_loop);
// Copy vgpr results to C global memory.
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
......@@ -161,97 +158,59 @@ template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayoutShape>
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayoutShape& thread_layout)
const ThreadLayout& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
SimpleDeviceMem a_mem(M * K * sizeof(DataType));
SimpleDeviceMem b_mem(K * N * sizeof(DataType));
SimpleDeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) *
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayoutShape>;
launch_and_time_kernel(StreamConfig{nullptr},
kernel,
dim3(grid_size),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>(
129, 129, 67, tile_shape, thread_layout);
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
int main(int argc, char* argv[])
{
using DataType = float;
const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{});
const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave);
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1, 4>(
512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave);
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s,
......@@ -15,6 +15,7 @@
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
static constexpr ck::index_t NumDimSpatial = 3;
using DataType = float;
......@@ -36,21 +37,20 @@ struct SimpleDeviceMem
void* p_mem_;
};
// Test copy from Global to Global through LDS and VGPR
template <typename InputTensor,
typename OutputTensor,
typename BlockShape,
typename ThreadLayoutShape>
__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
template <typename InputTensor, typename OutputTensor, typename BlockShape, typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__
DeviceImageToColumnPad0(InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
// grid layout (dim1, dim0)
const auto block_idxs =
ck::make_tuple(static_cast<ck::index_t>(blockIdx.y), static_cast<ck::index_t>(blockIdx.x));
// Get local tiles for global memory
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
// Get partition per thread
const auto input_local_partition =
......@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));
// User can choose appropriate number of threads and sizes per block
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
// This example doesn't support padding, user should select tile sizes
// which divides the shape completely
// which are divisible by the shape.
const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});
// Create buffers for global memory
......@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);
const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
ck::wrapper::size<0>(tile_shape)) *
ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
ck::wrapper::size<1>(tile_shape));
// grid layout (dim1, dim0)
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
ck::wrapper::size<1>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
ck::wrapper::size<0>(tile_shape));
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
decltype(output_tensor_global),
......@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G,
decltype(thread_layout)>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size),
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
input_tensor_global,
......@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
{1, 1, 1} /*filter_dilations*/);
return 0;
}
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
// Create layouts for global memory
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
// Apply padding
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
// Create tensors for global memory
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Create layouts and tensors for lds memory.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
// Create tile and partition for C global memory. Use specific gemm
// functions to get appropriate layouts.
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
// Define and clear c vgpr register
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
// Local partitions for lds memory
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
// Lamda to slice tensor, then create local tile and partition
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
// Copy first values to lds
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
// Pipeline loop
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
// Skip if only tile should be processed
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
// Copy data to A vgpr.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
// Synchronize.
ck::block_sync_lds();
// Copy data to B vgpr.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
// Perform gemm.
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
// Synchronize
ck::block_sync_lds();
// Copy data to A and B lds tiles.
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
// Handle tail.
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
// Store data from C vgpr to C global memory.
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
SimpleDeviceMem a_mem(M * K * sizeof(DataType));
SimpleDeviceMem b_mem(K * N * sizeof(DataType));
SimpleDeviceMem c_mem(M * N * sizeof(DataType));
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8, false>(
3840, 4096, 4096, tile_shape, thread_layout);
return 0;
}
// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s,
......@@ -12,10 +12,6 @@ Wrapper
Description
-------------------------------------
.. note::
The wrapper is under development and its functionality is limited.
The CK library provides a lightweight wrapper for more complex operations implemented in
the library.
......@@ -54,9 +50,15 @@ Output::
2 6 10 14 18 22 26 30
Tutorials:
* `GEMM tutorial <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/README.md>`_
Advanced examples:
* `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
* `Basic gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp>`_
* `Optimized gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp>`_
-------------------------------------
Layout
......
......@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
Sequence<false>,
Sequence<false>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
Sequence<true>,
Sequence<true>>{in_grid_desc,
make_tuple(src_tensor.GetMultiIdxOffsets()),
out_grid_desc,
make_tuple(dst_tensor.GetMultiIdxOffsets()),
tensor_operation::element_wise::PassThrough{}};
transfer.Run(tie(in_grid_desc),
tie(src_tensor.GetBuffer()),
......@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
{
// Perform copy from DynamicBuffer to StaticBuffer
const auto src_dst_slice_origin =
const auto dst_slice_origin_idxs =
generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
[&](auto I) {
if constexpr(I == VectorDim)
{
return Number<ScalarPerVector>{};
}
else
{
return I1;
}
},
Number<num_dims>{});
auto transfer =
ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
typename DstTensorType::TensorElementType,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
decltype(src_vector_tensor_lengths),
decltype(dim_access_order)>{
src_tensor.GetMultiIdxOffsets()};
auto transfer = ThreadwiseTensorSliceTransfer_v2<
std::remove_const_t<typename SrcTensorType::TensorElementType>,
std::remove_const_t<typename DstTensorType::TensorElementType>,
remove_cvref_t<decltype(in_grid_desc)>,
remove_cvref_t<decltype(out_grid_desc)>,
decltype(thread_slice_lengths),
decltype(dim_access_order),
VectorDim,
ScalarPerVector,
I1,
false,
false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};
transfer.Run(in_grid_desc,
src_dst_slice_origin,
src_tensor.GetBuffer(),
out_grid_desc,
src_dst_slice_origin,
dst_slice_origin_idxs,
dst_tensor.GetBuffer());
}
else
......@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
index_t ScalarPerVector,
typename SrcTensorType,
typename DstTensorType,
typename ThreadLayoutTuple>
__device__ void blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] ThreadLayoutTuple& thread_layout)
typename ThreadShape,
typename ThreadUnrolledDesc>
__device__ void
blockwise_copy(const SrcTensorType& src_tensor,
DstTensorType& dst_tensor,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
{
static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
......@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,
constexpr auto tile_lengths_seq =
generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq = generate_sequence_v2(
[](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
constexpr auto thread_layout_seq =
generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
constexpr auto dim_access_order = generate_sequence_v2(
[](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});
using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;
// Perform copy between DynamicBuffers
auto transfer = ThreadGroupTensorSliceTransfer_v7<
......
......@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()
/**
* \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
* data layout must be (NPerBlock, KPerBlock).
* stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
* (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
* or (K0PerBlock, NPerBlock, K1).
*
* \note C output Vgpr register layout (8D):
* - MXdlPerWave - The number of MFMA instructions run by single wave in M
......@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
* \tparam BlockSize Tensor to pad.
* \tparam GemmTraits Traits of gemm xdl operation.
* \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
* (MPerBlock, KPerBlock) layout.
* (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
* \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
* (NPerBlock, KPerBlock) layout.
* (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
* \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
*/
template <typename DataType,
......@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
const BTensorType& b_local_tile_tensor,
CTensorType& c_reg_tensor)
{
constexpr auto I3 = Number<3>{};
static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
......@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
DataType,
......@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
layout(c_local_tile_tensor).GetUnrolledDescriptor());
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
auto sliced_desc = transform_tensor_descriptor(
partition_desc,
make_tuple(
make_slice_transform(partition_shape.At(Number<0>{}),
m_thread_data_on_grid_idx[I0],
partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<1>{}),
n_thread_data_on_grid_idx[I0],
partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
make_slice_transform(partition_shape.At(Number<2>{}),
m_thread_data_on_grid_idx[I1],
partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<3>{}),
n_thread_data_on_grid_idx[I1],
partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
make_slice_transform(partition_shape.At(Number<4>{}),
m_thread_data_on_grid_idx[I2],
partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
make_slice_transform(partition_shape.At(Number<5>{}),
m_thread_data_on_grid_idx[I3],
partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
make_slice_transform(partition_shape.At(Number<6>{}),
m_thread_data_on_grid_idx[I4],
partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
make_slice_transform(partition_shape.At(Number<7>{}),
n_thread_data_on_grid_idx[I2],
partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
lower_upper_dims,
lower_upper_dims);
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
partition_shape, partition_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
c_local_tile_tensor.GetPointer(), partition_layout);
partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
n_thread_data_on_grid_idx[I0],
m_thread_data_on_grid_idx[I1],
n_thread_data_on_grid_idx[I1],
m_thread_data_on_grid_idx[I2],
m_thread_data_on_grid_idx[I3],
m_thread_data_on_grid_idx[I4],
n_thread_data_on_grid_idx[I2]));
return partition_tensor;
}
......@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
constexpr auto I6 = Number<6>{};
constexpr auto I7 = Number<7>{};
static_assert(typename ATileLayout::LayoutShape{}.Size() ==
typename BTileLayout::LayoutShape{}.Size());
constexpr bool is_integer =
is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;
constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
using ABlockDesc_K0_M_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
conditional_t<is_3d_desc,
typename ATileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
using BBlockDesc_K0_N_K1_Type =
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
conditional_t<is_3d_desc,
typename BTileLayout::LayoutUnrolledDescriptorType,
decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;
using BlockwiseGemmXdlops =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
......@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
vgpr_shape, vgpr_desc);
// Get vector type for Vgpr
using BlockwiseGemmCThreadBufferType =
remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
vgpr_layout);
}
......
......@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
}
}
template <typename... Ts, typename Shape, typename FlattenDescriptor>
template <typename... Ts, typename Shape, typename UnrolledDescriptor>
__host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
const Shape& shape,
const FlattenDescriptor& flatten_desc)
const UnrolledDescriptor& flatten_desc)
{
constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
......
......@@ -20,48 +20,57 @@ namespace wrapper {
* \tparam K1Value The number of K-dim elements that are packed together as
* a separate logical dimension. Usually aligns with vector load size.
*/
template <index_t MPerXDLValue,
index_t NPerXDLValue,
index_t MXdlPerWaveValue,
index_t NXdlPerWaveValue,
index_t K1Value>
template <typename MPerXDLValue,
typename NPerXDLValue,
typename MXdlPerWaveValue,
typename NXdlPerWaveValue,
typename K1Value>
struct BlockwisGemmXdlTraits
{
static constexpr index_t MPerXDL = MPerXDLValue;
static constexpr index_t NPerXDL = NPerXDLValue;
static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
static constexpr index_t K1 = K1Value;
static constexpr auto MPerXDL = MPerXDLValue{};
static constexpr auto NPerXDL = NPerXDLValue{};
static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
static constexpr auto K1 = K1Value{};
};
// K1 = 4
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
{
};
// K1 = 8
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
{
};
// K1 = 16
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
{
};
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
: BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
{
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
namespace ck {
namespace wrapper {
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
} // namespace wrapper
} // namespace ck
......@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace ck {
namespace wrapper {
......@@ -29,6 +30,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
namespace {
namespace detail {
/**
* \brief Generate packed (column-major) strides if not passed
*
......@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace
/// @endcond
......@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, strides));
}
/**
......@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
return Layout<Shape, UnrolledDescriptorType>(shape,
detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}
// Layout helpers
// get
/**
* \private
* \brief Get dim.
......@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
* \param layout Layout to create sub layout.
* \return Requsted sub layout.
*/
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
const auto& shape = layout.GetShape();
const auto new_shape = get<idx>(shape);
......@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
return layout.GetShape();
}
// pad
/**
* \brief Pad layout shapes to be adjusted to tile lengths.
*
*
* \param layout Layout to pad.
* \param tile_lengths Tile lengths to align layout shape.
* \return Padded layout.
*/
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
const TileLengths& tile_lengths)
{
auto& unrolled_desc = layout.GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
// Create layout
return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
// unmerge
/**
* \brief Unmerge selected dim in layout.
*
* \tparam Idx Index to dimension being unmerged.
* \param layout Layout to pad.
* \param new_lengths Dimensions into which the indicated dimension will be divided.
* \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
* \return Unmerged layout.
*/
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
const NewLengths& new_lengths,
[[maybe_unused]] const NewIdxs& new_indexes)
{
const auto& layout_shape = shape(layout);
auto& unrolled_desc = layout.GetUnrolledDescriptor();
constexpr auto dims = Shape::Size();
// Generate transforms
const auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == Idx)
{
return make_unmerge_transform(new_lengths);
}
else
{
return make_pass_through_transform(layout_shape.At(i));
}
},
Number<dims>{});
constexpr auto lower_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
constexpr auto upper_dims = generate_tuple(
[&](auto i) {
if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
{
constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
return to_sequence(idxs_tuple);
}
else
{
constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
return Sequence<index>{};
}
},
Number<dims>{});
const auto unmerged_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
const auto unmerged_shape =
generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
// Create layout
return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
} // namespace wrapper
} // namespace ck
......@@ -6,7 +6,6 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
......@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
* \brief Apply projection.
*
* \param base_tuple Tuple to apply projection.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Multi index after projection.
*/
template <typename MultiIndex, typename ProjectionTuple>
......@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
}
else
{
return base_tuple.At(i_num);
return make_tuple(base_tuple.At(i_num));
}
},
Number<MultiIndex::Size()>{});
......@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
* \brief Calculate shape with dims from projection.
*
* \param shape Base tensor shape.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Shape with dims from projection
*/
template <typename... Ts, typename... Ps>
......@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
*
* \param shape Base tensor shape.
* \param tile_shape Tile shape.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tuple with blocks number.
*/
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
const Tuple<Ls...>& tile_shape,
const Tuple<Ps...>& projection)
const Tuple<Ls...>& tile_shape)
{
auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
return generate_tuple(
[&](auto i) {
return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
size<i>(tile_shape));
},
[&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
Number<Tuple<Ls...>::Size()>{});
}
......@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
return thread_idxs * partition_lengths_seq + old_offset_idxs;
}
/**
* \brief Select dims to partition (skip if slice).
*
* \param block_idxs Input block indexes.
* \return Partitioned dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
const auto dims_to_partition = generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return Number<i>{};
}
else
{
return Tuple<>{};
}
},
Number<BlockIdxs::Size()>{});
// Remove empty tuples
return UnrollNestedTuple<0, 1>(dims_to_partition);
}
/**
* \brief Replace slices with zeros (Slice dims are not partitioned).
*
* \param block_idxs Input block indexes.
* \return Parsed dims.
*/
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
return generate_tuple(
[&](auto i) {
if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
{
return block_idxs.At(i);
}
else
{
return Number<0>{};
}
},
Number<BlockIdxs::Size()>{});
}
/**
* \brief Calculate default projection.
*
......@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}
/**
* \brief Calculate thread multi index from 1d thread index.
*
* \param thread_layout Layout of threads (could not be nested).
* \param thread_id Thread index represented as integer.
* \return Multi index.
*/
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id)
{
static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
"Thread layout should not be transformed.");
constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
constexpr auto shape = ThreadShape{};
constexpr auto strides = embed_transform.coefficients_;
return generate_tuple(
[&](auto i) {
constexpr auto num_i = Number<i>{};
return (thread_id / strides.At(num_i)) % shape.At(num_i);
},
Number<ThreadShape::Size()>{});
}
} // namespace detail
} // namespace
......@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
* is supported).
*
* \param tensor Tensor for partition.
* \param thread_lengths Layout of threads (could not be nested).
* \param thread_layout Layout of threads (could not be transformed).
* \param thread_id Thread index represented as integer.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
template <typename TensorType,
typename ThreadShape,
typename ThreadUnrolledDesc,
typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
[[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
[[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
const index_t thread_id,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
static_assert(!IsNestedTuple(ThreadShape{}));
// Calculate new partition shape
const auto& tensor_shape = shape(tensor);
// Calculate projected thread lengths
constexpr auto projected_thread_lengths =
detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
constexpr auto partition_shape =
detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
// Create Thread Cluster Descriptor
constexpr auto partition_shape_seq =
generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
Number<decltype(partition_shape)::Size()>{});
constexpr auto thread_lengths_seq =
generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
Number<ThreadLengthsTuple::Size()>{});
constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
// Calculate thread idxs and offsets
const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
// Apply projection on thread idxs to remove not needed idxs
const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Slice descriptor
const auto transforms = generate_tuple(
[&](auto i) {
return make_slice_transform(partition_shape.At(i),
offset_multi_idxs.At(i),
partition_shape.At(i) + offset_multi_idxs.At(i));
},
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
const auto lower_upper_dims =
generate_tuple([&](auto i) { return Sequence<i.value>{}; },
Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
auto sliced_desc =
transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
// Create layout
const auto partition_layout =
Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
partition_shape, unrolled_desc);
Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
partition_shape, sliced_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
// Apply offsets
partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
return partition_tensor;
}
......@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
* \param thread_id Thread index represented as integer.
* \return Partition tensor.
*/
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
const ThreadLengthsTuple& thread_lengths,
const index_t thread_id)
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
const index_t thread_id)
{
const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
......@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param projection Projection to remove selected dim from partitioning.
* slice(X) to remove, where X is dim size, Number<1>{} to keep.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \param projection Projection is used to remove selected dim from
* partitioning. Use `slice(X)` to remove dimension, where X is dim
* size. Use `Number<1>{}` to keep it.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
template <typename TensorType,
typename BlockShapeTuple,
typename BlockIdxs,
typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const index_t block_id,
const BlockIdxs& block_idxs,
const ProjectionTuple& projection)
{
static_assert(!IsNestedTuple(BlockShapeTuple{}));
constexpr bool is_default_projection =
is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
static_assert(!IsNestedTuple(BlockIdxs{}));
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
......@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();
// TODO: Enable block_2_tile_map partitioning for non-default projection.
if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
// Number of dims which are partitioned
constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
if constexpr(decltype(dims_to_partition)::Size() == I2)
{
// Optimized version for 2d tile shape [MxK]
const auto shape_with_projection_dims =
detail::CalculateShapeWithProjection(shape(tensor), projection);
// Set Value for M, N partition
const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
// Get 1D block id
const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
// Optimized version for 2d tile shape [MxN]
const auto block_2_tile_map =
BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
BlockShapeTuple{}.At(I1),
remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
NPerBlock,
remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
const index_t k_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
const auto offset_multi_idxs =
make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
// Apply 0 for non partitioned dims
const auto offset_multi_idxs = generate_tuple(
[&](auto i) {
if constexpr(i == dims_to_partition.At(I0))
{
return m_block_data_idx_on_grid;
}
else if constexpr(i == dims_to_partition.At(I1))
{
return n_block_data_idx_on_grid;
}
else
{
return Number<0>{};
}
},
Number<BlockShapeTuple::Size()>{});
const auto projected_offset_multi_idxs =
detail::ApplyProjection(offset_multi_idxs, projection);
// Create new layout and tensor
const auto tile_layout =
Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
aligned_desc);
Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
projected_tile_shape, aligned_desc);
auto tile_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
// Apply offsets
tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
return tile_tensor;
}
else
{
// Calculate offsets
// Sequence with data to process per block
constexpr auto projected_tile_shape =
detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
using ProjectedTileShapeTuple = decltype(projected_tile_shape);
constexpr auto projected_tile_shape_seq =
generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
Number<ProjectedTileShapeTuple::Size()>{});
// Tuple with number of blocks
const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
const auto block_idxs =
block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
const auto projected_block_idxs =
to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
// Create new layout and tensor
const auto tile_layout =
......@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
*
* \param tensor Tensor for partition.
* \param tile_shape Shapes of requested tile.
* \param block_id Block index represented as integer.
* \param block_idxs Tuple of block indexes represented as integer. If slice,
* then get whole dim.
* \return Tile tensor.
*/
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
const BlockShapeTuple& tile_shape,
const BlockIdxs& block_idxs)
{
const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
return make_local_tile(tensor, tile_shape, block_id, projection);
}
/**
* \brief Pad tensor shapes to be adjusted to tile lengths.
*
*
* \param tensor Tensor to pad.
* \param tile_lengths Tile lengths to align tensor shape.
* \return Padded tensor.
*/
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
const auto& tensor_shape = shape(tensor);
using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
// Generate sequence with ones to mark that all dims will be padded
constexpr auto do_pads_seq =
generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
// Create descriptor with padding
auto padded_desc =
tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
// Generate padded shape
const auto padded_shape = generate_tuple(
[&](auto i) {
const auto& dim = size<i>(tensor_shape);
const auto& tile_length = size<i>(tile_lengths);
return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
},
Number<TileLengths::Size()>{});
// Create layout and tensor
const auto padded_layout =
Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
auto partition_tensor =
make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
return partition_tensor;
return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
} // namespace wrapper
......
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
add_custom_target(test_wrapper)
add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
GPU_TARGETS MATCHES "gfx942")
add_gtest_executable(test_gemm test_gemm.cpp)
target_link_libraries(test_gemm PRIVATE utility)
add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
target_link_libraries(test_wrapper_gemm PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_gemm)
endif()
......@@ -20,23 +20,25 @@
template <typename InputTensor,
typename OutputTensor,
typename BlockShape,
typename ThreadLayoutShape,
typename ThreadLayout,
bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
OutputTensor output_tensor,
const BlockShape tile_shape,
const ThreadLayoutShape thread_layout)
const ThreadLayout thread_layout)
{
__shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
p_shared, ck::wrapper::make_layout(tile_shape));
const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
const auto block_idxs =
ck::make_tuple(static_cast<ck::index_t>(blockIdx.x), static_cast<ck::index_t>(blockIdx.y));
// Get local tiles for global memory
const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
const auto input_local_tile =
ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
const auto output_local_tile =
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
// Get partition per thread
const auto input_local_partition =
......@@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor,
// Allocate VGPR
auto tensor_vgpr =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
layout(lds_local_partition));
ck::wrapper::make_layout(shape(lds_local_partition)));
// Perform copy
if constexpr(UseOptimizedCopy)
......@@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS()
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);
const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}));
const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
const ck::index_t grid_size = ck::math::integer_divide_ceil(
ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));
const auto kernel = TestCopyDevice<decltype(input_tensor_global),
decltype(output_tensor_global),
......@@ -112,7 +117,7 @@ void PerformCopyGlobalToGlobalViaLDS()
UseOptimizedCopy>;
launch_and_time_kernel(StreamConfig{},
kernel,
dim3(grid_size),
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
input_tensor_global,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>
#include "ck/library/utility/host_tensor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
const std::vector<DataType>& b_data,
std::vector<DataType>& c_m_n_device_result,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;
Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));
a_m_k.mData = a_data;
b_k_n.mData = b_data;
auto ref_op = ReferenceGemmInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument = ref_op.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
if constexpr(DoPad)
{
return ck::wrapper::pad(layout, padding_dims);
}
else
{
return layout;
}
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
typename BlockShape,
typename ThreadLayout,
bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
const void* p_b,
void* p_c,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape tile_shape,
const ThreadLayout thread_layout)
{
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
constexpr auto K1 = GemmTraits::K1;
constexpr auto K0PerBlock = KPerBlock / K1;
const auto K0 = ck::math::integer_divide_ceil(K, K1);
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
const auto a_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
auto a_padded_global_layout =
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
auto b_padded_global_layout =
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
auto c_padded_global_layout =
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
// Reshape from M,K to K0,M,K1
const auto reshaped_dims_idxs =
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto a_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto b_padded_unmerged_global_layout =
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
static_cast<DataType*>(p_c), c_padded_global_layout);
// Add extra M and N
constexpr auto a_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, MPerBlock, K1),
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
ck::make_tuple(K0PerBlock, NPerBlock, K1),
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock];
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
static_cast<DataType*>(lds_b), b_tile_layout);
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
static_cast<ck::index_t>(blockIdx.x),
static_cast<ck::index_t>(blockIdx.y),
ck::wrapper::slice());
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
constexpr ck::index_t vector_dim = 2;
auto c_global_local_tile =
ck::wrapper::make_local_tile(c_global_tensor,
tile_shape_k0_m_n_k1,
block_idxs,
make_tuple(ck::wrapper::slice(K0PerBlock),
ck::Number<1>{},
ck::Number<1>{},
ck::wrapper::slice(K1)));
auto c_global_local_partition =
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>(c_global_local_tile);
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
decltype(a_tile_layout),
decltype(b_tile_layout),
ck::wrapper::size(thread_layout),
GemmTraits>();
ck::wrapper::clear(c_vgpr_reg);
auto a_lds_tensor_local_partition =
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
auto b_lds_tensor_local_partition =
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
const auto k_slice =
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
ck::wrapper::slice(),
ck::wrapper::slice());
auto local_tile = ck::wrapper::make_local_tile(
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
};
auto a_global_local_partition = make_global_partition(
a_global_tensor,
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
0);
auto b_global_local_partition = make_global_partition(
b_global_tensor,
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
0);
// (row-major vgpr layout)
auto a_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(a_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
auto b_vgpr_tensor =
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
ck::wrapper::make_layout(
shape(b_global_local_partition),
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
ck::wrapper::size<2>(a_global_local_partition),
ck::wrapper::size<2>(a_global_local_partition),
ck::Number<1>{})));
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
a_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
b_vgpr_tensor);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
b_lds_tensor_local_partition);
const ck::index_t num_loop =
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
if(num_loop > 1)
{
ck::index_t i = 0;
do
{
auto a_global_local_partition_i = make_global_partition(
a_global_tensor,
make_tuple(
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
i + 1);
auto b_global_local_partition_i = make_global_partition(
b_global_tensor,
make_tuple(
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
i + 1);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_global_local_partition_i, a_vgpr_tensor);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_global_local_partition_i, b_vgpr_tensor);
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::block_sync_lds();
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
a_vgpr_tensor, a_lds_tensor_local_partition);
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
b_vgpr_tensor, b_lds_tensor_local_partition);
++i;
} while(i < (num_loop - 1));
}
ck::block_sync_lds();
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
typename GemmTraits,
ck::index_t scalar_per_vector,
bool DoPadding,
typename BlockShape,
typename ThreadLayout>
void PerformGemm(const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const BlockShape& tile_shape,
const ThreadLayout& thread_layout)
{
// Global memory buffers
DeviceMem a_mem(M * K * sizeof(DataType));
DeviceMem b_mem(K * N * sizeof(DataType));
DeviceMem c_mem(M * N * sizeof(DataType));
std::vector<DataType> a_data(M * K);
std::vector<DataType> b_data(K * N);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
a_mem.ToDevice(a_data.data());
b_mem.ToDevice(b_data.data());
c_mem.SetZero();
const ck::index_t grid_size_x =
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y =
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
const auto kernel =
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
kernel,
dim3(grid_size_x, grid_size_y, 1),
dim3(ck::wrapper::size(thread_layout)),
0,
a_mem.GetDeviceBuffer(),
b_mem.GetDeviceBuffer(),
c_mem.GetDeviceBuffer(),
M,
N,
K,
tile_shape,
thread_layout);
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << std::endl;
std::vector<DataType> c_data(M * N);
c_mem.FromDevice(c_data.data());
CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}
TEST(TestGemm, Float)
{
using DataType = float;
// (dim1, dim2, dim0 thread layout)
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Int8)
{
using DataType = int8_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
PerformGemm<DataType,
ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1,
16,
false>(512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Half)
{
using DataType = ck::half_t;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8, false>(
512, 512, 128, tile_shape, thread_layout);
// Irregular case
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1, true>(
129, 129, 67, tile_shape, thread_layout);
}
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
using DataType = float;
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{});
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
512, 512, 128, tile_shape, thread_layout);
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
......
......@@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition)
const auto tensor =
ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
// row-major thread layout
const auto thread_layout =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}),
ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{}));
// 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim)
const auto thread_projection =
ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
......@@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile)
ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
const auto block_projection =
ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto num_blocks =
const auto grid_shape =
ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));
std::vector<ck::index_t> block_idxs(ck::wrapper::size(num_blocks));
std::iota(block_idxs.begin(), block_idxs.end(), 0);
std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t, ck::index_t>> block_idxs;
for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++)
{
for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++)
{
for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++)
{
block_idxs.emplace_back(i, j, k, 0);
}
}
}
for(auto block_idx : block_idxs)
{
constexpr ck::index_t projection_block_dim = ck::Number<2>{};
const auto packed_tile =
ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) *
ck::wrapper::size<2>(block_shape) *
ck::wrapper::size<2>(strides);
block_idx /= ck::wrapper::size<2>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<1>(block_idx) *
ck::wrapper::size<1>(block_shape) *
ck::wrapper::size<1>(strides);
block_idx /= ck::wrapper::size<1>(num_blocks);
expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) *
expected_tile_first_val += ck::wrapper::size<0>(block_idx) *
ck::wrapper::size<0>(block_shape) *
ck::wrapper::size<0>(strides);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment