Commit fa649421 authored by Jing Zhang

finished api

parent e845ad4c
@@ -79,27 +79,15 @@ struct ExecutionConfig final

bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
-    int group_count = problem_size.group_count;
+    auto group_count = problem_size.group_count;

    // GEMM shape
    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
-    std::vector<const void*> p_a, p_b;
-    std::vector<void*> p_c;
+    std::vector<void*> p_Cs;

    gemm_descs.reserve(group_count);

-    for(int i = 0; i < group_count; i++)
-    {
-        int M = problem_size.Ms[i];
-        int N = problem_size.Ns[i];
-        int K = problem_size.Ks[i];
-
-        int stride_A = problem_size.stride_As[i];
-        int stride_B = problem_size.stride_Bs[i];
-        int stride_C = problem_size.stride_Cs[i];
-
-        gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}});
-    }
+    int sum_of_m = 0;

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {

@@ -135,21 +123,22 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)

    std::size_t flop = 0, num_btype = 0;

-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
+    for(int i = 0; i < group_count; i++)
    {
+        sum_of_m += problem_size.Ms[i];
        a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
+            problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{})));
        b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
+            problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{})));
        c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
+            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));
        c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
+            problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{})));

        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                  << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
                  << std::endl;

-        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
+        flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i];
        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                     sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();

@@ -171,22 +160,47 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
        }
    }

-    for(std::size_t i = 0; i < gemm_descs.size(); i++)
+    using GemmKernelArgument = ck::tensor_operation::device::GemmKernelArgument;
+    std::vector<GemmKernelArgument> simple_gemm_kernel_args_;
+    simple_gemm_kernel_args_.reserve(group_count);
+
+    for(int i = 0; i < group_count; i++)
    {
-        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
+        a_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * sum_of_m * problem_size.Ks[i]));
        b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
-        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
-            sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize()));
+            sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i]));
+        c_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(EDataType) * sum_of_m * problem_size.Ns[i]));

-        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
-        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
+        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(),
+                                      a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType));
+        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(),
+                                      b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType));
        c_tensors_device[i]->SetZero();

-        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
-        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
-        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+        p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
+
+        gemm_descs.push_back({sum_of_m,
+                              problem_size.Ns[i],
+                              problem_size.Ks[i],
+                              problem_size.stride_As[i],
+                              problem_size.stride_Bs[i],
+                              problem_size.stride_Cs[i],
+                              {}});
+
+        simple_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
+                                            b_tensors_device[i]->GetDeviceBuffer(),
+                                            c_tensors_device[i]->GetDeviceBuffer(),
+                                            problem_size.Ms[i],
+                                            problem_size.Ns[i],
+                                            problem_size.Ks[i],
+                                            problem_size.stride_As[i],
+                                            problem_size.stride_Bs[i],
+                                            problem_size.stride_Cs[i]});
    }

    auto a_element_op = AElementOp{};

@@ -196,17 +210,24 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
    auto gemm = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

+    std::vector<const void*> p_As = {};
+    std::vector<const void*> p_Bs = {};
    std::vector<std::array<const void*, 0>> p_Ds = {};

    // do GEMM
    auto argument = gemm.MakeArgument(
-        p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
+        p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);

    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));

    gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());

-    gemm.SetKBatchSize(argument, 8);
+    hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
+                              simple_gemm_kernel_args_.data(),
+                              gemm.GetWorkSpaceSize(&argument),
+                              hipMemcpyHostToDevice));
+
+    gemm.SetKBatchSize(argument, 4);

    if(!gemm.IsSupportedArgument(argument))
    {

@@ -215,7 +236,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
                                 "not support this GEMM problem");
    }

-    invoker.Run(argument, StreamConfig{nullptr, false});
+    invoker.Run(argument, gemm_desc_workspace.GetDeviceBuffer(), StreamConfig{nullptr, false});

    bool pass = true;
    if(config.do_verification)

@@ -230,7 +251,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
        for(std::size_t i = 0; i < gemm_descs.size(); i++)
        {
-            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
+            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(),
+                                            c_device_tensors[i].mDesc.GetElementSize() *
+                                                sizeof(EDataType));

            auto ref_gemm = ReferenceGemmInstance{};
            auto ref_invoker = ref_gemm.MakeInvoker();

@@ -249,7 +272,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
    if(config.time_kernel)
    {
-        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        float ave_time = invoker.Run(argument,
+                                     gemm_desc_workspace.GetDeviceBuffer(),
+                                     StreamConfig{nullptr, config.time_kernel});

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

@@ -267,7 +292,8 @@ int main(int argc, char* argv[])
    problem_size.group_count = 16;

-    problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
+    problem_size.Ms = {
+        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};

    for(int i = 0; i < problem_size.group_count; i++)
    {
...
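For reference, a minimal, self-contained sketch (not taken from this commit) of the host-side flow the updated example now follows: build one kernel-argument record per group, pass empty pointer vectors to MakeArgument, and copy the packed records into the device workspace before launching. A plain byte vector stands in for the DeviceMem workspace; the struct mirrors the GemmKernelArgument added below, and all sizes and pointers are illustrative only.

// Sketch only: models the per-group argument packing the example performs.
// Standard C++; the byte vector plays the role of gemm_desc_workspace, which
// the real code fills with hipMemcpy before invoker.Run(...).
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

struct GemmKernelArgument // mirrors the struct introduced by this commit
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;
    std::int32_t M, N, K;
    std::int32_t StrideA, StrideB, StrideC;
};

int main()
{
    const int group_count = 3;
    std::vector<GemmKernelArgument> args;
    args.reserve(group_count);

    for(int i = 0; i < group_count; ++i)
    {
        // Device pointers would come from DeviceMem::GetDeviceBuffer() in the example;
        // nullptr is used here only so the sketch compiles on its own.
        args.push_back({nullptr, nullptr, nullptr, 128 * (i + 1), 256, 64, 64, 256, 256});
    }

    // The example copies args.data() into the workspace with hipMemcpy; a byte
    // vector stands in for that device buffer here.
    std::vector<std::uint8_t> workspace(args.size() * sizeof(GemmKernelArgument));
    std::memcpy(workspace.data(), args.data(), workspace.size());

    std::printf("packed %zu bytes of kernel arguments for %d groups\n",
                workspace.size(), group_count);
    return 0;
}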
@@ -8,6 +8,20 @@ namespace ck {
namespace tensor_operation {
namespace device {

+struct GemmKernelArgument
+{
+    const void* p_a_grid;
+    const void* p_b_grid;
+    void* p_c_grid;
+
+    index_t M;
+    index_t N;
+    index_t K;
+    index_t StrideA;
+    index_t StrideB;
+    index_t StrideC;
+};
+
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
...
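Since this struct is copied into the device workspace as raw bytes (hipMemcpy in the example, hipMemcpyWithStream in the Invoker later in this commit), it has to stay trivially copyable. A small standalone sketch of that constraint and of the minimum workspace size implied by one record per group; index_t is assumed here to be a 32-bit signed integer, matching ck::index_t in this codebase.

// Sketch only: compile-time checks and sizing for the workspace that holds one
// GemmKernelArgument per group.
#include <cstddef>
#include <cstdint>
#include <type_traits>

using index_t = std::int32_t; // assumption for this sketch

struct GemmKernelArgument
{
    const void* p_a_grid;
    const void* p_b_grid;
    void* p_c_grid;

    index_t M;
    index_t N;
    index_t K;
    index_t StrideA;
    index_t StrideB;
    index_t StrideC;
};

static_assert(std::is_trivially_copyable<GemmKernelArgument>::value,
              "must be memcpy-able into the device workspace");

// The Invoker copies exactly group_count records, so the workspace must hold
// at least this many bytes.
constexpr std::size_t workspace_bytes(std::size_t group_count)
{
    return group_count * sizeof(GemmKernelArgument);
}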
@@ -83,8 +83,15 @@
        const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
        const auto local_b2c_tile_map = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};

+        const auto m_loops = local_b2c_tile_map.CalculateMLoops(c_grid_desc_m_n);
+        index_t m_id       = 0;
+
+        do
+        {
        const auto block_2_ctile_map =
-            GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_id * block_size);
+            GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_id * block_size, m_id);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
            p_a_grid,

@@ -103,6 +110,11 @@
            k_batch,
            static_cast<void*>(p_shared),
            block_2_ctile_map);

+            m_id += 1;
+        } while(m_id < m_loops);
+
#else
        ignore = gemm_descs_const;
        ignore = group_count;
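The hunk above wraps the GEMM body in a do/while so that each workgroup iterates over M blocks: m_loops comes from CalculateMLoops (added further down in this commit, ceil(M / MPerBlock)), and each pass rebuilds the tile map with an extra M-block offset m_id. A host-side sketch of just that loop structure; MPerBlock is arbitrary, and the real Block2ETileMapKSplit mapping is not reproduced here.

// Sketch only: the loop count and per-iteration offset, nothing more.
#include <cstdio>

int main()
{
    const int M = 300, MPerBlock = 128;                   // MPerBlock chosen for illustration
    const int m_loops = (M + MPerBlock - 1) / MPerBlock;  // integer_divide_ceil(M, MPerBlock)

    int m_id = 0;
    do
    {
        // In the kernel, this is where GroupedGemmBlock2ETileMap(..., m_id) is
        // rebuilt and GridwiseGemm::Run(...) is invoked for this m_id.
        std::printf("m_id = %d of %d\n", m_id, m_loops);
        m_id += 1;
    } while(m_id < m_loops);
    return 0;
}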
@@ -267,11 +279,21 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        grid_size_ = 0;

        group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

-        if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
-             group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
-             group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
+        if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
+             0 == ck::type_convert<ck::index_t>(p_As.size())))
        {
-            throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
+            throw std::runtime_error("wrong! group_count_ != p_As || 0 != p_As.size");
+        }
+
+        if(!(group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) ||
+             0 == ck::type_convert<ck::index_t>(p_Bs.size())))
+        {
+            throw std::runtime_error("wrong! group_count_ != p_Bs || 0 != p_Bs.size");
+        }
+
+        if(!(group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
+        {
+            throw std::runtime_error("wrong! group_count_ != p_Es");
        }

        gemm_kernel_args_.reserve(group_count_);

@@ -297,17 +319,13 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

+            grid_size_ += grid_size_grp;
+
            const index_t block_start = grid_size_;
            const index_t block_end = grid_size_ + grid_size_grp;

-            grid_size_ += grid_size_grp;
-
-            // block-to-e-tile map
-            auto grouped_block_2_ctile_map =
-                GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-
-            auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
-                                       type_convert<const BDataType*>(p_Bs[i]),
+            auto karg = KernelArgument{
+                p_As.size() == 0 ? nullptr : type_convert<const ADataType*>(p_As[i]),
+                p_Bs.size() == 0 ? nullptr : type_convert<const BDataType*>(p_Bs[i]),
                                       type_convert<EDataType*>(p_Es[i]),
                                       M,
                                       N,

@@ -349,16 +367,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
            const auto local_b2c_tile_map =
                Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
            const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

+            grid_size_ += grid_size_grp;
+
            const index_t block_start = grid_size_;
            const index_t block_end = grid_size_ + grid_size_grp;

-            grid_size_ += grid_size_grp;
-
-            // block-to-e-tile map
-            auto grouped_block_2_ctile_map =
-                GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
-
            karg.KPadded = k_padded;
            karg.K0 = k0;
            karg.k_batch = K_BATCH;

@@ -378,30 +391,64 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    // Invoker
    struct Invoker : public BaseInvoker
    {
-        struct SimpleGemmArgument
-        {
-            const void* p_a_grid;
-            const void* p_b_grid;
-            void* p_c_grid;
-
-            index_t M;
-            index_t N;
-            index_t K;
-            index_t StrideA;
-            index_t StrideB;
-            index_t StrideC;
-        };
-
        float Run(const Argument& arg,
                  const void* gemm_descs_dev,
                  const StreamConfig& stream_config = StreamConfig{})
        {
-            using GemmArgumentType = SimpleGemmArgument;
+            using GemmArgumentType = GemmKernelArgument;

            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

+            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
+            {
+                const auto& karg = arg.gemm_kernel_args_[i].karg_;
+                if(stream_config.log_level_ > 0)
+                {
+                    karg.Print();
+                }
+
+                std::cout << "Group id: " << i << " block_size: "
+                          << arg.gemm_kernel_args_[0].block_end_ -
+                                 arg.gemm_kernel_args_[0].block_start_
+                          << std::endl;
+
+                auto kbatch = karg.k_batch;
+
+                if(!GridwiseGemm::CheckValidity(karg))
+                {
+                    std::ostringstream err;
+                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+
+                K0 = karg.K0;
+                bool not_all_have_main_k0_block_loop_same =
+                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
+
+                if(not_all_have_main_k0_block_loop_same)
+                {
+                    std::ostringstream err;
+                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
+                        << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+
+                if(not_all_have_kbatch_value_same)
+                {
+                    std::ostringstream err;
+                    err << "Not all gemms have same kbatch value (=1 or >1)! "
+                        << "group [" << i << "], kbatch: " << kbatch
+                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
+                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
+                    throw std::runtime_error(err.str());
+                }
+            }

            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {

@@ -491,60 +538,19 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-            std::vector<SimpleGemmArgument> simple_gemm_kernel_args_;
-            simple_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());
-
-            index_t K0 = arg.gemm_kernel_args_[0].karg_.K0;
-            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
-            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
+            std::vector<GemmKernelArgument> grouped_gemm_kernel_args_;
+            grouped_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& karg = arg.gemm_kernel_args_[i].karg_;
-                if(stream_config.log_level_ > 0)
-                {
-                    karg.Print();
-                }
-
-                auto kbatch = karg.k_batch;
-
-                std::cout << "Group id: " << i << " block_size: "
-                          << arg.gemm_kernel_args_[i].block_end_ -
-                                 arg.gemm_kernel_args_[i].block_start_
-                          << std::endl;
-
-                if(!GridwiseGemm::CheckValidity(karg))
+
+                if(karg.p_a_grid == nullptr || karg.p_b_grid == nullptr || karg.p_c_grid == nullptr)
                {
-                    std::ostringstream err;
-                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
-                        << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
-                }
-
-                K0 = karg.K0;
-                bool not_all_have_main_k0_block_loop_same =
-                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
-
-                if(not_all_have_main_k0_block_loop_same)
-                {
-                    std::ostringstream err;
-                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                        << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
-                }
-
-                if(not_all_have_kbatch_value_same)
-                {
-                    std::ostringstream err;
-                    err << "Not all gemms have same kbatch value (=1 or >1)! "
-                        << "group [" << i << "], kbatch: " << kbatch
-                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
-                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
-                    throw std::runtime_error(err.str());
+                    throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
                }

-                simple_gemm_kernel_args_.push_back({karg.p_a_grid,
+                grouped_gemm_kernel_args_.push_back({karg.p_a_grid,
                                                    karg.p_b_grid,
                                                    karg.p_c_grid,
                                                    karg.M,

@@ -555,12 +561,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                                    karg.StrideC});
            }

-            using GemmArgumentType = SimpleGemmArgument;
+            using GemmArgumentType = GemmKernelArgument;

            hip_check_error(
                hipMemcpyWithStream(arg.p_workspace_,
-                                    simple_gemm_kernel_args_.data(),
-                                    simple_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
+                                    grouped_gemm_kernel_args_.data(),
+                                    grouped_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));
...
@@ -315,6 +315,11 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }

+    __device__ index_t CalculateMLoops(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
+    }
+
    private:
    index_t M01_;
    index_t KSplit_;

@@ -586,17 +591,22 @@ struct OffsettedBlockToCTileMap
    using underlying_type = UnderlyingBlockToCTileMap;

    __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                                 index_t block_start)
+                                                 index_t block_start,
+                                                 index_t mblock_id_off = 0)
    {
        block_to_ctile_map_ = block_to_ctile_map;
        block_start_ = block_start;
+        mblock_id_off_ = mblock_id_off;
    }

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
-        return block_to_ctile_map_.CalculateBottomIndex(
+        auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
            make_multi_index(idx_top[Number<0>{}] - block_start_));
+
+        return make_tuple(
+            idx_bot[Number<0>{}], idx_bot[Number<1>{}] + mblock_id_off_, idx_bot[Number<2>{}]);
    }

    template <typename CTileIdx, typename CTileDim>

@@ -620,6 +630,7 @@ struct OffsettedBlockToCTileMap
    UnderlyingBlockToCTileMap block_to_ctile_map_;
    index_t block_start_;
+    index_t mblock_id_off_;
};

/**
...
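To make the new index arithmetic concrete: OffsettedBlockToCTileMap now subtracts block_start_ from the 1D block id, asks the underlying map for a (k_batch, m_block, n_block) index, and shifts only the M coordinate by mblock_id_off_. A toy sketch with plain arrays in place of ck's multi-index types; the row-major underlying map below is illustrative only, not the real Block2ETileMapKSplit layout.

// Sketch only: models the offset arithmetic, not the real tile maps.
#include <array>
#include <cstdio>

struct ToyUnderlyingMap
{
    int MBlocks, NBlocks;
    std::array<int, 3> CalculateBottomIndex(int block_1d_id) const
    {
        const int kbatch = block_1d_id / (MBlocks * NBlocks);
        const int rest   = block_1d_id % (MBlocks * NBlocks);
        return {kbatch, rest / NBlocks, rest % NBlocks}; // (k_batch_id, m_block, n_block)
    }
};

struct ToyOffsettedMap
{
    ToyUnderlyingMap underlying;
    int block_start;   // first 1D block id owned by this group
    int mblock_id_off; // the new third constructor argument

    std::array<int, 3> CalculateBottomIndex(int block_1d_id) const
    {
        auto idx = underlying.CalculateBottomIndex(block_1d_id - block_start);
        return {idx[0], idx[1] + mblock_id_off, idx[2]}; // shift only the M coordinate
    }
};

int main()
{
    ToyOffsettedMap map{{4, 8}, /*block_start=*/32, /*mblock_id_off=*/2};
    auto idx = map.CalculateBottomIndex(45);
    std::printf("k_batch=%d m_block=%d n_block=%d\n", idx[0], idx[1], idx[2]);
    return 0;
}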
@@ -621,9 +621,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            return;
        }

-        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]);
+        const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
...
@@ -25,7 +25,9 @@ struct DeviceMem
    void* GetDeviceBuffer() const;
    std::size_t GetBufferSize() const;
    void ToDevice(const void* p) const;
+    void ToDevice(const void* p, const std::size_t cpySize) const;
    void FromDevice(void* p) const;
+    void FromDevice(void* p, const std::size_t cpySize) const;
    void SetZero() const;
    template <typename T>
    void SetValue(T x) const;
...

@@ -19,11 +19,21 @@ void DeviceMem::ToDevice(const void* p) const
    hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}

+void DeviceMem::ToDevice(const void* p, const std::size_t cpySize) const
+{
+    hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+}
+
void DeviceMem::FromDevice(void* p) const
{
    hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}

+void DeviceMem::FromDevice(void* p, const std::size_t cpySize) const
+{
+    hip_check_error(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+}
+
void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }

DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
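A short usage sketch of the new sized overloads: copy only cpySize bytes instead of the buffer's full mMemSize, as the updated example does for its A/B/C tensors. The DeviceMem(size) constructor is the one the example already uses for the workspace; the include path and the buffer sizes below are assumptions for illustration.

// Sketch only: partial host/device copies with the new overloads.
#include <vector>
#include "ck/library/utility/device_memory.hpp" // assumed include path for DeviceMem

void partial_copy_example()
{
    std::vector<float> host(1024, 1.0f);

    DeviceMem buf(sizeof(float) * host.size());        // full allocation
    buf.ToDevice(host.data(), sizeof(float) * 256);    // upload only the first 256 elements
    buf.FromDevice(host.data(), sizeof(float) * 256);  // read the same 256 elements back
}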