"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "6e2c6159ac8b0567902334a28ceae7b0e6829305"
Commit a4b08b57 authored by Adam Osewski

Generalize kernel to grouped_gemm and add more test cases.

parent 7316bd15
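
The kernel now walks a single flat, grid-strided tile id across all groups and locates the owning GEMM by accumulating per-group tile counts. As a minimal illustration of that lookup (not part of this commit; names are hypothetical), the search the device loop performs is roughly:

// Host-side sketch of the group lookup done in the kernel's do/while loop below,
// assuming each group's tile_count has already been computed from its B2C tile map.
#include <cstddef>
#include <vector>

struct GroupLookup
{
    std::size_t group_id;   // which GEMM owns the tile
    std::size_t local_tile; // tile index relative to that GEMM's first tile
};

inline GroupLookup FindGroupForTile(const std::vector<std::size_t>& group_tile_counts,
                                    std::size_t flat_tile_id)
{
    std::size_t offset = 0;
    for(std::size_t g = 0; g < group_tile_counts.size(); ++g)
    {
        if(flat_tile_id < offset + group_tile_counts[g])
            return {g, flat_tile_id - offset};
        offset += group_tile_counts[g];
    }
    // Callers are expected to pass a valid id; fall back to the last group otherwise.
    return {group_tile_counts.size() - 1, 0};
}
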
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#include <memory>
#include <vector>

#include <gtest/gtest.h>
@@ -19,17 +20,34 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace {
using namespace ck;

struct GemmArgDesc
{
GemmArgDesc(index_t M_,
index_t N_,
index_t K_,
const float* p_A_,
const float* p_B_,
float* p_C_,
index_t tile_count_)
: M{M_}, N{N_}, K{K_}, p_A{p_A_}, p_B{p_B_}, p_C{p_C_}, tile_count{tile_count_}
{
}
index_t M;
index_t N;
index_t K;
const float* p_A;
const float* p_B;
float* p_C;
index_t tile_count;
};
template <index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
__global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p_gemm_descs,
                                                            float* p_workspace,
                                                            uint32_t* p_flags,
                                                            index_t tile_count,
@@ -39,6 +57,41 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    StridedReductionTileLoop work_scheduler{tile_count, p_flags};
// early exit if no work.
if(work_scheduler.tile_id_ >= tile_count)
return;
index_t group_id = 0;
index_t offset = 0;
index_t grid_size_grp = p_gemm_descs[group_id].tile_count;
index_t gemm_tile_id_start = 0;
index_t gemm_tile_id_end = grid_size_grp;
do
{
        // Find the GEMM group corresponding to our tile.
while(!(work_scheduler.tile_id_ >= gemm_tile_id_start &&
work_scheduler.tile_id_ < gemm_tile_id_end))
{
// Step to next GEMM group and update data tile bounds.
offset += grid_size_grp;
group_id++;
grid_size_grp = p_gemm_descs[group_id].tile_count;
gemm_tile_id_start = offset;
gemm_tile_id_end = offset + grid_size_grp;
}
const index_t M = p_gemm_descs[group_id].M;
const index_t N = p_gemm_descs[group_id].N;
const index_t K = p_gemm_descs[group_id].K;
const auto p_A = p_gemm_descs[group_id].p_A;
const auto p_B = p_gemm_descs[group_id].p_B;
const auto p_C = p_gemm_descs[group_id].p_C;
        const auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N));
        BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(c_grid_desc_m_n, k_batch);
@@ -53,7 +106,8 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
        const index_t stride_c = N;

        // K is the contiguous dim in memory, as well as fastest changing dim in B2C mapping.
        const auto block_work_idx =
            b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
@@ -79,30 +133,37 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
        } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

        // if next [M,N] tile
        if(!b2c_tile_map.IsFirstKSplitBlock())
        {
            // Assume we have MPerBlock x NPerBlock tile per each workgroup in contiguous memory.
            p_workspace[get_block_1d_id() * MPerBlock * NPerBlock + get_thread_local_1d_id()] =
                partial_result;
        }
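        // Each output [M,N] tile is split into k_batch data tiles, so the group's data-tile
        // offset divided by k_batch gives the output-tile (flag) offset of this group.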
        const index_t output_tile_idx = b2c_tile_map.GetOutputTileIdx();
        const index_t output_tile_idx_offset = offset / k_batch;
        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
        // The workgroup which processed the first K tile accumulates results and stores to GMEM
        if(b2c_tile_map.IsFirstKSplitBlock())
        {
            // Wait until all other blocks for this [M,N] tile store their results.
            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
            // Accumulate partial results. We can have a different number of workgroups to
            // reduce, thus we read the actual flag value.
            const index_t flag_v = __builtin_amdgcn_readfirstlane(
                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
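            // Start at i = 1: this workgroup's own partial result is still held in registers,
            // so only the neighbouring workgroups' workspace slots need to be read back.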
for(index_t i = 1; i < flag_v; ++i)
            {
                partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                              i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
            }
            // Signal waiting blocks that they can start using their workspace.
work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
            // write result
            const index_t C_m_tile_offset = block_m_id * MPerBlock;
            const index_t C_thread_tile_m_idx = get_thread_local_1d_id() / NPerBlock;
@@ -112,10 +173,14 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
            p_C[(C_m_tile_offset + C_thread_tile_m_idx) * stride_c + C_n_tile_offset +
                C_thread_tile_n_idx] = partial_result;
        }
else
{
work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
}
} while(work_scheduler.HasTile());
#else
    ignore = p_gemm_descs;
    ignore = p_workspace;
    ignore = p_flags;
    ignore = tile_count;
@@ -126,7 +191,7 @@ __global__ void gemm_naive_strided_tile_loop_reduce(index_t M,
} // namespace

template <index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock>
struct GroupedGemmStridedTileLoopReduce
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using AElementOp = PassThrough;
@@ -139,7 +204,7 @@ struct GemmStridedTileLoopReduce
    using AccDataType = float;

    constexpr static auto DeviceGemmKernel =
        grouped_gemm_naive_strided_tile_loop_reduce<MPerBlock, NPerBlock, KPerBlock>;

    using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                            BDataType,
@@ -149,32 +214,75 @@ struct GemmStridedTileLoopReduce
                                                                            BElementOp,
                                                                            CElementOp>;
    GroupedGemmStridedTileLoopReduce() = default;

    bool Run(std::vector<index_t> Ms,
std::vector<index_t> Ns,
std::vector<index_t> Ks,
index_t k_batch,
index_t grid_size)
    {
        EXPECT_TRUE(Ms.size() == Ns.size() && Ms.size() == Ks.size());
        std::size_t group_count = Ms.size();
std::vector<Tensor<float>> a_m_k;
std::vector<Tensor<float>> b_k_n;
std::vector<Tensor<float>> c_m_n_host;
std::vector<Tensor<float>> c_m_n_device;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
        std::vector<DeviceMemPtr> a_m_k_device_buf;
        std::vector<DeviceMemPtr> b_k_n_device_buf;
        std::vector<DeviceMemPtr> c_m_n_device_buf;

        std::vector<GemmArgDesc> gemm_descs;
        gemm_descs.reserve(group_count);
        index_t tile_count = 0;
        for(std::size_t i = 0; i < group_count; ++i)
        {
            a_m_k.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ks[i]}, {Ks[i], 1})));
            b_k_n.push_back(Tensor<float>(HostTensorDescriptor({Ks[i], Ns[i]}, {Ns[i], 1})));
c_m_n_host.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ns[i]}, {Ns[i], 1})));
c_m_n_device.push_back(Tensor<float>(HostTensorDescriptor({Ms[i], Ns[i]}, {Ns[i], 1})));
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k[i]);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n[i]);
c_m_n_host[i].SetZero();
c_m_n_device[i].SetZero();
a_m_k_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(float) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_k_n_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(float) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_m_n_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(float) * c_m_n_device[i].mDesc.GetElementSpaceSize()));
a_m_k_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_k_n_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_m_n_device_buf[i]->SetZero();
BlockToCTileMap_LinearKSplit<MPerBlock, NPerBlock> b2c_tile_map(Ms[i], Ns[i], k_batch);
index_t grp_tile_count = b2c_tile_map.CalculateGridSize(Ms[i], Ns[i]);
tile_count += grp_tile_count;
gemm_descs.emplace_back(
Ms[i],
Ns[i],
Ks[i],
reinterpret_cast<float*>(a_m_k_device_buf[i]->GetDeviceBuffer()),
reinterpret_cast<float*>(b_k_n_device_buf[i]->GetDeviceBuffer()),
reinterpret_cast<float*>(c_m_n_device_buf[i]->GetDeviceBuffer()),
grp_tile_count);
}
DeviceMem gemm_descs_device_buf{gemm_descs.size() * sizeof(GemmArgDesc)};
gemm_descs_device_buf.ToDevice(gemm_descs.data());
        DeviceMem gemm_workspace, gemm_flags;

        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
        // This is the number of MN-output tiles which we cover with workgroups.
        // We launch k_batch / tiles_per_block workgroups for each output tile.
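        // E.g. with k_batch = 16 K-slices per output tile and tiles_per_block = 4, roughly
        // four workgroups cooperate on (and later reduce) each output [M,N] tile.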
@@ -186,17 +294,13 @@ struct GemmStridedTileLoopReduce
        gemm_workspace.SetZero();
        gemm_flags.SetZero();

        launch_and_time_kernel(
            StreamConfig{nullptr, false},
            DeviceGemmKernel,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            reinterpret_cast<const GemmArgDesc*>(gemm_descs_device_buf.GetDeviceBuffer()),
            reinterpret_cast<float*>(gemm_workspace.GetDeviceBuffer()),
            reinterpret_cast<uint32_t*>(gemm_flags.GetDeviceBuffer()),
            tile_count,
@@ -209,48 +313,155 @@ struct GemmStridedTileLoopReduce
        auto ref_gemm = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();
bool pass = true;
for(std::size_t i = 0; i < group_count; ++i)
{
            auto ref_argument = ref_gemm.MakeArgument(
                a_m_k[i], b_k_n[i], c_m_n_host[i], a_element_op, b_element_op, c_element_op);
            ref_invoker.Run(ref_argument);
            c_m_n_device_buf[i]->FromDevice(c_m_n_device[i].mData.data());
pass = pass && ck::utils::check_err(c_m_n_device[i], c_m_n_host[i]);
}
        return pass;
    }
};

TEST(TestStridedReductionTileLoop, GroupedGemm_SingleDataTile)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch = 4;
const index_t grid_size = 4;
std::vector<index_t> Ms(1, MPerBlock);
std::vector<index_t> Ns(1, NPerBlock);
std::vector<index_t> Ks(1, KPerBlock * kbatch);
    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_SingleOutputMultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch = 16;
const index_t grid_size = 4;
    std::vector<index_t> Ms(1, MPerBlock);
    std::vector<index_t> Ns(1, NPerBlock);
std::vector<index_t> Ks(1, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}

TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleDataTiles)
{
    constexpr index_t MPerBlock = 8;
    constexpr index_t NPerBlock = 32;
    constexpr index_t KPerBlock = 32;
    constexpr index_t BlockSize = 256;
    const index_t kbatch = 16;
const index_t grid_size = 64;
std::vector<index_t> Ms(1, MPerBlock * 4);
std::vector<index_t> Ns(1, NPerBlock * 4);
std::vector<index_t> Ks(1, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleOutputDataTilesPerBlock_1Group)
{
constexpr index_t MPerBlock = 8;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BlockSize = 256;
const index_t kbatch = 6;
const index_t grid_size = 3;
std::vector<index_t> Ms(1, MPerBlock * 2);
std::vector<index_t> Ns(1, NPerBlock);
std::vector<index_t> Ks(1, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_MultipleOutputDataTilesPerBlock_NGroup)
{
constexpr index_t MPerBlock = 8;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BlockSize = 256;
const index_t kbatch = 6;
const index_t grid_size = 6;
std::vector<index_t> Ms(2, MPerBlock * 2);
std::vector<index_t> Ns(2, NPerBlock);
std::vector<index_t> Ks(2, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockLTKBatch)
{
constexpr index_t MPerBlock = 8;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BlockSize = 256;
const index_t kbatch = 5;
const index_t grid_size = 7;
std::vector<index_t> Ms(2, MPerBlock * 2);
std::vector<index_t> Ns(2, NPerBlock);
std::vector<index_t> Ks(2, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockGTKBatch)
{
constexpr index_t MPerBlock = 8;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BlockSize = 256;
const index_t kbatch = 5;
const index_t grid_size = 5;
std::vector<index_t> Ms(2, MPerBlock * 2);
std::vector<index_t> Ns(2, NPerBlock * 2);
std::vector<index_t> Ks(2, KPerBlock * kbatch);
EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
Ms, Ns, Ks, kbatch, grid_size)));
}
TEST(TestStridedReductionTileLoop, GroupedGemm_CrossGroups_CrossK_TilePerBlockGTKBatch2)
{
constexpr index_t MPerBlock = 8;
constexpr index_t NPerBlock = 32;
constexpr index_t KPerBlock = 32;
constexpr index_t BlockSize = 256;
const index_t kbatch = 5;
    // The grid covers more tiles than there are actual data tiles.
const index_t grid_size = 6;
std::vector<index_t> Ms(2, MPerBlock * 2);
std::vector<index_t> Ns(2, NPerBlock * 2);
std::vector<index_t> Ks(2, KPerBlock * kbatch);
    EXPECT_TRUE((GroupedGemmStridedTileLoopReduce<BlockSize, MPerBlock, NPerBlock, KPerBlock>{}.Run(
        Ms, Ns, Ks, kbatch, grid_size)));
}