Commit 3165d5d7 authored by Jing Zhang

add examples

parent fee2002c
......@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
......@@ -57,7 +58,233 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
#include "run_grouped_gemm_example.inc"
struct ProblemSize final
{
std::vector<ck::index_t> Ms;
std::vector<ck::index_t> Ns;
std::vector<ck::index_t> Ks;
std::vector<ck::index_t> stride_As;
std::vector<ck::index_t> stride_Bs;
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
auto group_count = problem_size.group_count;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
std::vector<void*> p_Cs;
gemm_descs.reserve(group_count);
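    // Running total of M over all groups; used below when sizing the per-group
    // device buffers and as the M field of each GemmDesc.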
int sum_of_m = 0;
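    // Helper: build a HostTensorDescriptor whose strides follow the layout
    // (row-major -> {stride, 1}, column-major -> {1, stride}).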
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<EDataType>> c_host_tensors;
std::vector<Tensor<EDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < group_count; i++)
{
sum_of_m += problem_size.Ms[i];
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl;
flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
switch(config.init_method)
{
case 0: break;
case 1:
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
case 2:
a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
}
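    // Host-side staging area for the per-group kernel arguments; it is copied
    // into the device workspace before the kernel launch.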
using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>;
std::vector<GroupedGemmKernelArgument> grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
a_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
c_tensors_device.emplace_back(
std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
c_tensors_device[i]->SetZero();
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m,
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
problem_size.stride_Cs[i],
{}});
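        // Note: the GemmDesc carries sum_of_m as its M, while the device-side
        // kernel argument below carries the group's actual M.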
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
b_tensors_device[i]->GetDeviceBuffer(),
{},
c_tensors_device[i]->GetDeviceBuffer(),
problem_size.Ms[i],
problem_size.Ns[i],
problem_size.Ks[i],
problem_size.stride_As[i],
problem_size.stride_Bs[i],
{},
problem_size.stride_Cs[i]});
}
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CDEElementOp{};
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
std::vector<const void*> p_As = {};
std::vector<const void*> p_Bs = {};
std::vector<std::array<const void*, 0>> p_Ds = {};
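    // p_As/p_Bs/p_Ds are intentionally left empty: the per-group A/B pointers are
    // delivered through grouped_gemm_kernel_args_ (copied into the workspace below)
    // rather than through MakeArgument.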
// do GEMM
auto argument = gemm.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
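    // The workspace holds one GroupedGemmKernelArgument per group; copy the
    // host-side argument array into it before launching.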
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
gemm.GetWorkSpaceSize(&argument),
hipMemcpyHostToDevice));
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
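    // This Run overload takes the device buffer holding the per-group kernel
    // arguments directly (see the new Invoker::Run added in device_grouped_gemm_xdl.hpp).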
invoker.Run(argument, gemm_desc_workspace.GetDeviceBuffer(), StreamConfig{nullptr, false});
bool pass = true;
if(config.do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
EDataType,
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
c_device_tensors[i].mDesc.GetElementSize() *
sizeof(EDataType));
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
b_tensors[i],
c_host_tensors[i],
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
}
}
if(config.time_kernel)
{
float ave_time = invoker.Run(argument,
gemm_desc_workspace.GetDeviceBuffer(),
StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
}
return pass;
}
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[])
......
......@@ -459,9 +459,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return math::integer_divide_ceil(M_, MPerBlock_);
}
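    // Changed below: CalculateGridSize ignores M (M0 is pinned to 1), so the grid
    // is sized by the N tiling alone.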
__host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
__host__ static constexpr index_t CalculateGridSize(index_t /*M*/, index_t N)
{
const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return M0 * N0;
......@@ -566,11 +566,27 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
0 == ck::type_convert<ck::index_t>(p_As.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
0 == ck::type_convert<ck::index_t>(p_Bs.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Ds.size()) ||
0 == ck::type_convert<ck::index_t>(p_Ds.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Ds || 0 != p_Ds.size");
}
if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_Es");
}
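        // p_As, p_Bs and p_Ds may now be empty; only p_Es must match group_count_.
        // Empty p_As/p_Bs turn into nullptr placeholders below, with the real pointers
        // expected to arrive through the device-side kernel argument buffer.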
gemm_desc_kernel_arg_.reserve(group_count_);
......@@ -605,9 +621,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
DsGridDesc_M_N ds_grid_desc_m_n;
std::array<index_t, NumDTensor> StrideDs;
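            // Fill StrideDs from the GemmDesc so it can be forwarded into the kernel
            // argument (previously an empty {} was passed).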
static_for<0, NumDTensor, 1>{}([&](auto j) {
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
StrideDs[j] = gemm_descs[i].stride_Ds_[j];
ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
M, N, gemm_descs[i].stride_Ds_[j]);
});
......@@ -655,8 +674,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
e_grid_desc_m_n);
gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{p_As[i],
p_Bs[i],
GemmBiasTransKernelArg{p_As.size() == 0 ? nullptr : p_As[i],
p_Bs.size() == 0 ? nullptr : p_Bs[i],
p_ds_grid,
p_Es[i],
M,
......@@ -664,7 +683,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
K,
StrideA,
StrideB,
{},
StrideDs,
StrideC,
a_grid_desc_m_k,
b_grid_desc_n_k,
......@@ -702,6 +721,58 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{
using Argument = DeviceOp::Argument;
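        // New overload: the per-group kernel arguments are already resident on the
        // device at grouped_gemm_kernel_args_dev, so no host-to-device copy happens
        // here; the pointer is handed straight to the grouped GEMM kernel.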
float Run(const Argument& arg,
const void* grouped_gemm_kernel_args_dev,
const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
......@@ -753,10 +824,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
{
throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
}
grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_,
{},
arg.gemm_desc_kernel_arg_[i].ds_ptr_,
arg.gemm_desc_kernel_arg_[i].e_ptr_,
arg.gemm_desc_kernel_arg_[i].M_,
arg.gemm_desc_kernel_arg_[i].N_,
......@@ -774,48 +852,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
hipMemcpyHostToDevice,
stream_config.stream_id_));
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
float ave_time = Run(arg, arg.p_workspace_, stream_config);
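            // Delegates the launch to the device-pointer Run overload above, using the
            // argument's workspace (filled by the copy above) as the device-side buffer.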
return ave_time;
}
......@@ -932,7 +969,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
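    // Workspace sizing now matches what the Run paths actually copy in: one
    // GroupedGemmKernelArgument<NumDTensor> per group instead of one GemmBiasTransKernelArg.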
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
return dynamic_cast<const Argument*>(p_arg)->group_count_ *
sizeof(GroupedGemmKernelArgument<NumDTensor>);
}
};
......
......@@ -25,7 +25,9 @@ struct DeviceMem
void* GetDeviceBuffer() const;
std::size_t GetBufferSize() const;
void ToDevice(const void* p) const;
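    // Partial-copy overloads: transfer only cpySize bytes instead of the full mMemSize.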
void ToDevice(const void* p, const std::size_t cpySize) const;
void FromDevice(void* p) const;
void FromDevice(void* p, const std::size_t cpySize) const;
void SetZero() const;
template <typename T>
void SetValue(T x) const;
......
......@@ -19,11 +19,21 @@ void DeviceMem::ToDevice(const void* p) const
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
{
hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
}
void DeviceMem::FromDevice(void* p) const
{
hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
{
hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }