Commit e86a2748 authored by Jing Zhang

fixed comments: unified blk2ctile; add test

parent 8df7bd01
@@ -76,9 +76,9 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
-                                                              std::vector<const void*> p_b,
-                                                              std::vector<void*> p_c,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                              std::vector<const void*>& p_b,
+                                                              std::vector<void*>& p_c,
                                                               std::vector<GemmShape>& gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,

...
@@ -223,6 +223,35 @@ struct DeviceGroupedGemmXdl
         CThreadTransferDstScalarPerVector,
         NumPrefetch>;

+    struct GroupedGemmBlock2CTileMap
+    {
+        GroupedGemmBlock2CTileMap()
+        {
+            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
+            BlockStart_ = -1;
+        }
+
+        GroupedGemmBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n,
+                                  index_t M01,
+                                  index_t N01,
+                                  ck::index_t BlockStart)
+        {
+            block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);
+            BlockStart_ = BlockStart;
+        }
+
+        template <typename TopIdx>
+        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+        {
+            return block_2_ctile_map_.CalculateBottomIndex(
+                make_multi_index(idx_top[I0] - BlockStart_));
+        }
+
+        private:
+        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        ck::index_t BlockStart_;
+    };
+
     struct GemmDescKernelArg
     {
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
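The new GroupedGemmBlock2CTileMap wraps the per-group default tile map and folds the group's block offset into the lookup itself: each group owns a contiguous range of workgroup IDs beginning at BlockStart_, and CalculateBottomIndex subtracts that offset before delegating to the underlying map (the default constructor's BlockStart_ = -1 looks like an "unassigned" sentinel). A minimal standalone sketch of the idea, using hypothetical ToyTileMap and ToyGroupedTileMap stand-ins rather than the CK types:

    #include <cassert>

    // Toy stand-in for a per-group tile map: maps a 0-based block id to an
    // (m, n) tile coordinate on a C grid that is n_tile_cols tiles wide.
    struct ToyTileMap
    {
        int n_tile_cols;

        void CalculateBottomIndex(int local_block_id, int& m, int& n) const
        {
            m = local_block_id / n_tile_cols;
            n = local_block_id % n_tile_cols;
        }
    };

    // Mirrors GroupedGemmBlock2CTileMap: shift the global block id down by the
    // group's starting block, then delegate to the underlying map.
    struct ToyGroupedTileMap
    {
        ToyTileMap map;
        int block_start;

        void CalculateBottomIndex(int global_block_id, int& m, int& n) const
        {
            map.CalculateBottomIndex(global_block_id - block_start, m, n);
        }
    };

    int main()
    {
        // A group that owns global blocks [8, 20) on a 3-tile-wide C grid.
        ToyGroupedTileMap grouped{ToyTileMap{3}, 8};

        int m = -1, n = -1;
        grouped.CalculateBottomIndex(10, m, n); // local block id 10 - 8 = 2
        assert(m == 0 && n == 2);

        return 0;
    }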
@@ -232,21 +261,21 @@ struct DeviceGroupedGemmXdl
         typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
+        GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;

         const ADataType* a_ptr;
         const BDataType* b_ptr;
         CDataType* c_ptr;

-        ck::index_t BlockStart, BlockEnd;
+        ck::index_t BlockStart_, BlockEnd_;
     };
     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(std::vector<const void*> p_a,
-                 std::vector<const void*> p_b,
-                 std::vector<void*> p_c,
+        Argument(std::vector<const void*>& p_a,
+                 std::vector<const void*>& p_b,
+                 std::vector<void*>& p_c,
                  std::vector<GemmShape>& gemm_shapes,
                  index_t M01,
                  index_t N01,
@@ -301,15 +330,15 @@ struct DeviceGroupedGemmXdl
             const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                 GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);

-            const auto block_2_ctile_map_ =
-                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
+            const auto grouped_gemm_block_2_ctile_map_ =
+                GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart);

             gemm_desc_kernel_arg_.push_back(
                 GemmDescKernelArg{a_grid_desc_k0_m_k1_,
                                   b_grid_desc_k0_n_k1_,
                                   c_grid_desc_m_n_,
                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                  block_2_ctile_map_,
+                                  grouped_gemm_block_2_ctile_map_,
                                   static_cast<const ADataType*>(p_a[i]),
                                   static_cast<const BDataType*>(p_b[i]),
                                   static_cast<CDataType*>(p_c[i]),
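The BlockStart passed to the map here is the running offset at which this group's tiles begin in the fused launch grid: the groups' tile grids are laid end to end, so each group's [BlockStart_, BlockEnd_) range is a prefix sum of the per-group tile counts. A host-side sketch of that bookkeeping (the tile counts below are made up for illustration):

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Hypothetical per-group tile counts (M tiles x N tiles per group).
        std::vector<int> tiles_per_group = {12, 40, 84};

        int block_start = 0;
        for(std::size_t i = 0; i < tiles_per_group.size(); ++i)
        {
            int block_end = block_start + tiles_per_group[i];
            std::printf("group %zu: blocks [%d, %d)\n", i, block_start, block_end);
            block_start = block_end; // the next group starts where this one ends
        }

        return 0;
    }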
@@ -470,9 +499,9 @@ struct DeviceGroupedGemmXdl
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(std::vector<const void*> p_a,
-                             std::vector<const void*> p_b,
-                             std::vector<void*> p_c,
+    static auto MakeArgument(std::vector<const void*>& p_a,
+                             std::vector<const void*>& p_b,
+                             std::vector<void*>& p_c,
                              std::vector<GemmShape> gemm_shapes,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
@@ -484,9 +513,9 @@ struct DeviceGroupedGemmXdl
     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
-                                                      std::vector<const void*> p_b,
-                                                      std::vector<void*> p_c,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
+                                                      std::vector<const void*>& p_b,
+                                                      std::vector<void*>& p_c,
                                                       std::vector<GemmShape>& gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,

...
@@ -80,11 +80,10 @@ __global__ void
 #if 1
     static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+        if(block_id >= gemm_desc_[i].BlockStart_ && block_id < gemm_desc_[i].BlockEnd_ &&
            i < group_count)
         {
             auto group_id = i;
-            const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;

             GridwiseGemm::template Run<HasMainK0BlockLoop>(
                 gemm_desc_[group_id].a_ptr,
@@ -97,8 +96,7 @@ __global__ void
                 a_element_op,
                 b_element_op,
                 c_element_op,
-                gemm_desc_[group_id].block_2_ctile_map_,
-                block_id_grp);
+                gemm_desc_[group_id].grouped_gemm_block_2_ctile_map_);
         }
     });
 #else
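On the device side, each workgroup first finds the group whose half-open block range contains its block id; the kernel does this with a compile-time static_for over MaxGroupCount guarded by i < group_count, and, after this change, passes the unmodified global block id straight into GridwiseGemm::Run, since the offset subtraction now lives inside GroupedGemmBlock2CTileMap. The selection logic, sketched as a plain runtime loop with an invented find_group helper:

    #include <cstdio>

    // Hypothetical per-group descriptor holding just the block range, mirroring
    // the BlockStart_/BlockEnd_ fields of GemmDescKernelArg.
    struct GroupRange
    {
        int BlockStart_;
        int BlockEnd_;
    };

    // Return the index of the group whose range [BlockStart_, BlockEnd_)
    // contains block_id, or -1 if no group does.
    int find_group(const GroupRange* desc, int group_count, int block_id)
    {
        for(int i = 0; i < group_count; ++i)
        {
            if(block_id >= desc[i].BlockStart_ && block_id < desc[i].BlockEnd_)
                return i;
        }
        return -1;
    }

    int main()
    {
        // Two groups: blocks [0, 8) and [8, 20).
        const GroupRange desc[] = {{0, 8}, {8, 20}};

        std::printf("block 3  -> group %d\n", find_group(desc, 2, 3));  // 0
        std::printf("block 12 -> group %d\n", find_group(desc, 2, 12)); // 1

        return 0;
    }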
@@ -426,8 +424,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         const AElementwiseOperation& a_element_op,
         const BElementwiseOperation& b_element_op,
         const CElementwiseOperation& c_element_op,
-        const Block2CTileMap& block_2_ctile_map,
-        ck::index_t block_id = get_block_1d_id())
+        const Block2CTileMap& block_2_ctile_map)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
@@ -440,7 +437,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // divide block work by [M, N]
         const auto block_work_idx =
-            block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_id));
+            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

         // HACK: this forces m/n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =

...
@@ -35,6 +35,7 @@ add_subdirectory(space_filling_curve)
 add_subdirectory(conv_util)
 add_subdirectory(reference_conv_fwd)
 add_subdirectory(gemm)
+add_subdirectory(grouped_gemm)
 add_subdirectory(gemm_split_k)
 add_subdirectory(conv2d_fwd)
 add_subdirectory(convnd_fwd)

...
add_test_executable(test_grouped_gemm_fp16 grouped_gemm_fp16.cpp)
target_link_libraries(test_grouped_gemm_fp16 PRIVATE host_tensor)
target_link_libraries(test_grouped_gemm_fp16 PRIVATE device_grouped_gemm_instance)
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <cmath>
#include <memory>
#include <half.hpp>

#include "config.hpp"
#include "print.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_gemm.hpp"
#include "device_tensor.hpp"
#include "device_grouped_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "reference_gemm.hpp"
#include "gemm_specialization.hpp"
#include "test_util.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceGroupedGemmPtr_ = ck::tensor_operation::device::DeviceGroupedGemmPtr<
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough>;

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {

void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
    std::vector<DeviceGroupedGemmPtr_>&);

} // namespace device_grouped_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
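// NOTE: add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances is only
// forward-declared here; its definition is expected to come from the
// device_grouped_gemm_instance library that the test links against.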
namespace {

using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CDataType = ck::half_t;
using AccDataType = float;

using ALayout = ck::tensor_layout::gemm::RowMajor;
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
using CLayout = ck::tensor_layout::gemm::RowMajor;

template <typename T>
static bool check_err(const Tensor<T>& ref, const Tensor<T>& result)
{
    // Absolute tolerance for fp16 outputs accumulated in fp32.
    constexpr float max_diff = 1e-2f;

    if(ref.mData.size() != result.mData.size())
        return false;

    for(std::size_t i = 0; i < ref.mData.size(); ++i)
    {
        const double diff = std::abs(double(ref.mData[i]) - double(result.mData[i]));
        if(diff > max_diff)
        {
            std::cout << double(ref.mData[i]) << "," << double(result.mData[i]) << std::endl;
            return false;
        }
    }

    return true;
}
bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
{
    int group_count = 4;

    // GEMM shape
    std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

    gemm_shapes.reserve(group_count);

    for(int i = 0; i < group_count; i++)
    {
        int M = 256 + 256 * i;
        int N = 128 + 128 * i;
        int K = 64 + 64 * i;
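        // Stride convention (see f_host_tensor_descriptor below): a row-major
        // matrix uses its column count as the row pitch, a column-major matrix
        // uses its row count as the column pitch. E.g. for group 0 (M=256,
        // N=128, K=64): StrideA = K (A row-major), StrideB = K (B column-major),
        // StrideC = N (C row-major).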
        int AStride = std::is_same<ck::tensor_layout::gemm::RowMajor, ALayout>::value ? K : M;
        int BStride = std::is_same<ck::tensor_layout::gemm::RowMajor, BLayout>::value ? N : K;
        int CStride = std::is_same<ck::tensor_layout::gemm::RowMajor, CLayout>::value ? N : M;

        gemm_shapes.push_back({M, N, K, AStride, BStride, CStride});
    }
    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({stride, 1}));
            }
            else
            {
                return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                            std::vector<std::size_t>({1, stride}));
            }
        };
    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<BDataType>> b_tensors;
    std::vector<Tensor<CDataType>> c_host_tensors;
    std::vector<Tensor<CDataType>> c_device_tensors;

    a_tensors.reserve(group_count);
    b_tensors.reserve(group_count);
    c_host_tensors.reserve(group_count);
    c_device_tensors.reserve(group_count);

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;

    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;

    a_tensors_device.reserve(group_count);
    b_tensors_device.reserve(group_count);
    c_tensors_device.reserve(group_count);
    std::size_t flop = 0, num_btype = 0;

    for(std::size_t i = 0; i < gemm_shapes.size(); i++)
    {
        a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{})));
        b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].K, gemm_shapes[i].N, gemm_shapes[i].StrideB, BLayout{})));
        c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));
        c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
            gemm_shapes[i].M, gemm_shapes[i].N, gemm_shapes[i].StrideC, CLayout{})));

        // std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
        //           << " b_k_n: " << b_tensors[i].mDesc
        //           << " c_m_n: " << c_device_tensors[i].mDesc << std::endl;

        flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                     sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();

        a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
    }
    for(std::size_t i = 0; i < gemm_shapes.size(); i++)
    {
        a_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
        b_tensors_device.emplace_back(
            std::make_unique<DeviceMem>(sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));

        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }
    auto a_element_op = PassThrough{};
    auto b_element_op = PassThrough{};
    auto c_element_op = PassThrough{};

    // do GEMM
    auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
    auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
        p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);

    if(!groupedGemmPtr->IsSupportedArgument(argument_ptr.get()))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }

    invoker_ptr->Run(argument_ptr.get());
    for(std::size_t i = 0; i < gemm_shapes.size(); i++)
    {
        c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());

        using ReferenceGemmInstance = ck::tensor_operation::host::
            ReferenceGemm<ADataType, BDataType, CDataType, PassThrough, PassThrough, PassThrough>;

        auto ref_gemm    = ReferenceGemmInstance{};
        auto ref_invoker = ref_gemm.MakeInvoker();

        auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
                                                  b_tensors[i],
                                                  c_host_tensors[i],
                                                  a_element_op,
                                                  b_element_op,
                                                  c_element_op);

        ref_invoker.Run(ref_argument);

        // compare device result against the host reference
        bool res = check_err(c_host_tensors[i], c_device_tensors[i]);

        std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl;

        if(!res)
            return false;
    }

    return true;
}
} // anonymous namespace
int main()
{
    std::vector<DeviceGroupedGemmPtr_> groupedGemmPtrs;
    ck::tensor_operation::device::device_grouped_gemm_instance::
        add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(groupedGemmPtrs);

    bool res = true;
    for(auto& gemmPtr : groupedGemmPtrs)
    {
        res &= TestGroupedGemm(gemmPtr);
    }

    std::cout << "TestGroupedGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl;

    return res ? 0 : 1;
}