Commit f4f94f70 authored by Jing Zhang

merge group and non-group

parent bb9c4a89
@@ -10,7 +10,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_grouped_gemm_xdlops_v2r3.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp"
namespace ck {
@@ -182,7 +182,7 @@ struct DeviceGroupedGemmXdl
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
-using GridwiseGemm = GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
+using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
......
@@ -54,6 +54,82 @@ __global__ void
block_2_ctile_map);
}
template <typename GridwiseGemm,
typename FloatAB,
typename FloatC,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
bool HasMainK0BlockLoop,
index_t MaxGroupCount>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
{
auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_[group_id].block_2_ctile_map_,
block_id_grp);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
gemm_desc_ptr[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
}
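To make the dispatch in the new kernel easier to follow: each group's GemmDesc carries a half-open [BlockStart, BlockEnd) workgroup range, and the kernel scans the statically sized descriptor array for the range containing its block id, then re-bases the block id to the start of its group. Below is a minimal host-side sketch of that lookup; GroupDesc and its fields are illustrative stand-ins mirroring the diff, not CK's actual types.

```cpp
// Minimal CPU sketch of the block-to-group lookup used by the kernel above.
// GroupDesc is a hypothetical stand-in for the real GemmDesc.
#include <cstdio>

struct GroupDesc
{
    int BlockStart; // first workgroup assigned to this group
    int BlockEnd;   // one past the last workgroup of this group
};

constexpr int MaxGroupCount = 4;

// Returns the index of the group owning block_id, or -1 if out of range.
int find_group(const GroupDesc (&descs)[MaxGroupCount], int group_count, int block_id)
{
    for(int i = 0; i < group_count; ++i)
    {
        if(block_id >= descs[i].BlockStart && block_id < descs[i].BlockEnd)
            return i;
    }
    return -1;
}

int main()
{
    // Three groups of 2, 3, and 1 tiles respectively, packed back to back.
    GroupDesc descs[MaxGroupCount] = {{0, 2}, {2, 5}, {5, 6}, {0, 0}};

    for(int block_id = 0; block_id < 6; ++block_id)
    {
        const int g = find_group(descs, 3, block_id);
        // block_id_grp is the block's offset within its own group, which is
        // exactly what the kernel passes to GridwiseGemm::Run.
        const int block_id_grp = block_id - descs[g].BlockStart;
        std::printf("block %d -> group %d, local block %d\n", block_id, g, block_id_grp);
    }
}
```

The device code uses static_for rather than a plain runtime loop because StaticallyIndexedArray can only be subscripted with compile-time indices; the `i < group_count` test inside the body then masks off unused descriptor slots at runtime.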
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
@@ -350,7 +426,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CElementwiseOperation& c_element_op,
-const Block2CTileMap& block_2_ctile_map)
+const Block2CTileMap& block_2_ctile_map,
+ck::index_t block_id = get_block_1d_id())
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -363,7 +440,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// divide block work by [M, N]
const auto block_work_idx =
-block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
......
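The defaulted `block_id` parameter added to Run is what lets the grouped and non-grouped paths share one GridwiseGemm: existing call sites pass nothing and get `get_block_1d_id()` exactly as before, while the grouped kernel passes a group-local id so each group's block-to-C-tile map sees indices starting at zero. A minimal sketch of that mechanism, with a deliberately simplified signature that is not the real Run API:

```cpp
// Illustrative sketch only: simplified signatures, not the real GridwiseGemm.
#include <cstdio>

static int g_hw_block_id = 7; // stand-in for the hardware block id
static int get_block_1d_id() { return g_hw_block_id; }

// The defaulted parameter keeps existing (non-grouped) call sites unchanged;
// default arguments are evaluated at each call site, so the hardware id is
// fetched only when the caller omits the argument.
static void Run(int block_id = get_block_1d_id())
{
    std::printf("mapping C tile for block %d\n", block_id);
}

int main()
{
    Run();      // non-grouped path: defaults to get_block_1d_id(), prints 7
    Run(7 - 5); // grouped path: group-local id = block_id - BlockStart, prints 2
}
```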
@@ -69,10 +69,12 @@ void profile_grouped_gemm_impl(int do_verification,
}
};
-if(!(Ms.size() == Ns.size() && Ns.size() == Ks.size() && Ks.size() == StrideAs.size() &&
-     StrideAs.size() == StrideBs.size() && StrideBs.size() == StrideCs.size()))
+int group_count = Ms.size();
+if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() &&
+     group_count == StrideBs.size() && group_count == StrideCs.size()))
{
throw std::runtime_error("wrong! inconsistent Ms, Ns, Ks, StrideA/B/Cs size\n");
throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n");
}
std::vector<Tensor<ADataType>> a_m_k;
@@ -125,9 +127,22 @@ void profile_grouped_gemm_impl(int do_verification,
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_device_buf, b_device_buf, c_device_buf;
std::vector<GemmShape> gemm_shapes;
a_device_buf.reserve(group_count);
b_device_buf.reserve(group_count);
c_device_buf.reserve(group_count);
for(int i = 0; i < Ms.size(); i++)
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
p_a.reserve(group_count);
p_b.reserve(group_count);
p_c.reserve(group_count);
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
gemm_shapes.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_device_buf.push_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSize()));
@@ -141,15 +156,11 @@ void profile_grouped_gemm_impl(int do_verification,
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
-gemm_shapes.push_back({Ms[i],
-                       Ns[i],
-                       Ks[i],
-                       StrideAs[i],
-                       StrideBs[i],
-                       StrideCs[i],
-                       a_device_buf[i]->GetDeviceBuffer(),
-                       b_device_buf[i]->GetDeviceBuffer(),
-                       c_device_buf[i]->GetDeviceBuffer()});
+gemm_shapes.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i]});
+p_a.push_back(a_device_buf[i]->GetDeviceBuffer());
+p_b.push_back(b_device_buf[i]->GetDeviceBuffer());
+p_c.push_back(c_device_buf[i]->GetDeviceBuffer());
}
// add device GEMM instances
@@ -204,7 +215,10 @@ void profile_grouped_gemm_impl(int do_verification,
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(gemm_shapes,
gemm_ptr->MakeArgumentPointer(p_a,
p_b,
p_c,
gemm_shapes,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{});
......
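With this refactor, device pointers travel separately from shape metadata: the profiler builds three per-group pointer vectors plus a vector of plain GemmShape records and hands all four to MakeArgumentPointer. A hedged sketch of that data flow follows; the GemmShape layout is read off the diff, and the host vectors stand in for the DeviceMem buffers.

```cpp
// Sketch of the reworked data flow: pointers now travel separately from
// shapes. GemmShape's exact layout is an assumption read off the diff.
#include <cstdio>
#include <vector>

struct GemmShape
{
    int M, N, K, StrideA, StrideB, StrideC; // shape/stride metadata only
};

int main()
{
    const int group_count = 2;
    // Host buffers standing in for the per-group DeviceMem allocations.
    std::vector<float> a0(16), b0(16), c0(16), a1(64), b1(64), c1(64);

    std::vector<const void*> p_a{a0.data(), a1.data()};
    std::vector<const void*> p_b{b0.data(), b1.data()};
    std::vector<void*> p_c{c0.data(), c1.data()};
    std::vector<GemmShape> gemm_shapes{{4, 4, 4, 4, 4, 4}, {8, 8, 8, 8, 8, 8}};

    // The instance would now receive pointers and shapes separately:
    //   gemm_ptr->MakeArgumentPointer(p_a, p_b, p_c, gemm_shapes,
    //                                 a_op, b_op, c_op);
    std::printf("prepared %d groups\n", group_count);
}
```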