"...composable_kernel-1.git" did not exist on "0f912e205eec6e349060f2203a8eeabc5e7ba075"
Commit 5a5468f4 authored by Jing Zhang's avatar Jing Zhang
Browse files

add SetDeviceKernelArgs

parent 3165d5d7
...@@ -60,6 +60,8 @@ int main() ...@@ -60,6 +60,8 @@ int main()
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0;
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
Ms.push_back(256 + 256 * distrib(gen)); Ms.push_back(256 + 256 * distrib(gen));
...@@ -69,6 +71,8 @@ int main() ...@@ -69,6 +71,8 @@ int main()
StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]); StrideAs.push_back(std::is_same<Row, ALayout>::value ? Ks[i] : Ms[i]);
StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]); StrideBs.push_back(std::is_same<Row, BLayout>::value ? Ns[i] : Ks[i]);
StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]); StrideEs.push_back(std::is_same<Row, ELayout>::value ? Ns[i] : Ms[i]);
sum_of_m += Ms[i];
} }
auto f_matrix_space_size = auto f_matrix_space_size =
...@@ -102,6 +106,10 @@ int main() ...@@ -102,6 +106,10 @@ int main()
gemm_descs.reserve(group_count); gemm_descs.reserve(group_count);
std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<>>
grouped_gemm_kernel_args_;
grouped_gemm_kernel_args_.reserve(group_count);
for(int i = 0; i < group_count; ++i) for(int i = 0; i < group_count; ++i)
{ {
a_dev_bufs.emplace_back(sizeof(ADataType) * a_dev_bufs.emplace_back(sizeof(ADataType) *
...@@ -111,11 +119,23 @@ int main() ...@@ -111,11 +119,23 @@ int main()
e_dev_bufs.emplace_back(sizeof(EDataType) * e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}}); gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}});
p_a.push_back(a_dev_bufs[i].GetDeviceBuffer()); p_a.push_back(a_dev_bufs[i].GetDeviceBuffer());
p_b.push_back(b_dev_bufs[i].GetDeviceBuffer()); p_b.push_back(b_dev_bufs[i].GetDeviceBuffer());
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(),
b_dev_bufs[i].GetDeviceBuffer(),
{},
e_dev_bufs[i].GetDeviceBuffer(),
Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
{},
StrideEs[i]});
} }
using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout, using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<ALayout,
...@@ -162,13 +182,20 @@ int main() ...@@ -162,13 +182,20 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); // op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(),
op_ptr->GetWorkSpaceSize(argument_ptr.get()),
hipMemcpyHostToDevice);
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float ave_time = invoker_ptr->Run(argument_ptr.get(),
gemm_desc_workspace.GetDeviceBuffer(),
StreamConfig{nullptr, true});
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(std::size_t j = 0; j < gemm_descs.size(); ++j) for(std::size_t j = 0; j < gemm_descs.size(); ++j)
......
...@@ -223,7 +223,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -223,7 +223,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); // gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(), hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(), grouped_gemm_kernel_args_.data(),
...@@ -237,7 +237,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -237,7 +237,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
"not support this GEMM problem"); "not support this GEMM problem");
} }
invoker.Run(argument, gemm_desc_workspace.GetDeviceBuffer(), StreamConfig{nullptr, false}); gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
bool pass = true; bool pass = true;
if(config.do_verification) if(config.do_verification)
...@@ -273,9 +275,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -273,9 +275,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
if(config.time_kernel) if(config.time_kernel)
{ {
float ave_time = invoker.Run(argument, float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
gemm_desc_workspace.GetDeviceBuffer(),
StreamConfig{nullptr, config.time_kernel});
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
......
...@@ -66,6 +66,8 @@ struct DeviceGroupedGemm : public BaseOperator ...@@ -66,6 +66,8 @@ struct DeviceGroupedGemm : public BaseOperator
CElementwiseOperation c_element_op) = 0; CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0;
}; };
} // namespace device } // namespace device
......
...@@ -564,6 +564,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -564,6 +564,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{ {
grid_size_ = 0; grid_size_ = 0;
grouped_gemm_kernel_args_dev = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) || if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) ||
...@@ -713,6 +715,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -713,6 +715,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_; std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_; std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
const void* grouped_gemm_kernel_args_dev;
index_t grid_size_; index_t grid_size_;
}; };
...@@ -721,65 +725,15 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -721,65 +725,15 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
{ {
using Argument = DeviceOp::Argument; using Argument = DeviceOp::Argument;
float Run(const Argument& arg,
const void* grouped_gemm_kernel_args_dev,
const StreamConfig& stream_config = StreamConfig{})
{
bool has_main_k_block_loop = true;
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
GroupedGemmKernelArgument<NumDTensor>,
GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(grouped_gemm_kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
bool has_main_k_block_loop = true; bool has_main_k_block_loop = true;
#if 1
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args; std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size()); grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
#endif
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
...@@ -824,13 +778,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -824,13 +778,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
} }
if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr || #if 1
arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
{
throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
}
grouped_gemm_kernel_args.push_back( grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_, GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_, arg.gemm_desc_kernel_arg_[i].b_ptr_,
...@@ -843,16 +791,80 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -843,16 +791,80 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
arg.gemm_desc_kernel_arg_[i].StrideB_, arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_, arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_}); arg.gemm_desc_kernel_arg_[i].StrideE_});
#endif
} }
hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, float ave_time = 0;
grouped_gemm_kernel_args.data(),
grouped_gemm_kernel_args.size() * auto launch_kernel = [&](auto has_main_k_block_loop_) {
sizeof(GroupedGemmKernelArgument<NumDTensor>), const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
hipMemcpyHostToDevice, GroupedGemmKernelArgument<NumDTensor>,
stream_config.stream_id_)); GemmSpec,
ALayout,
BLayout,
DsLayout,
ELayout,
Block2ETileMap,
GroupedGemmBlock2ETileMap,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
has_main_k_block_loop_>;
const index_t grid_size_grp = arg.gemm_desc_kernel_arg_[0].BlockEnd_ -
arg.gemm_desc_kernel_arg_[0].BlockStart_;
const void* kernel_args_dev = nullptr;
if(arg.grouped_gemm_kernel_args_dev != nullptr)
{
kernel_args_dev = arg.grouped_gemm_kernel_args_dev;
}
else
{
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{
if(arg.gemm_desc_kernel_arg_[i].a_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].b_ptr_ == nullptr ||
arg.gemm_desc_kernel_arg_[i].e_ptr_ == nullptr)
{
throw std::runtime_error("wrong! p_a/b/c_grid is nullptr");
}
}
hipGetErrorString(
hipMemcpyWithStream(arg.p_workspace_,
grouped_gemm_kernel_args.data(),
grouped_gemm_kernel_args.size() *
sizeof(GroupedGemmKernelArgument<NumDTensor>),
hipMemcpyHostToDevice,
stream_config.stream_id_));
kernel_args_dev = arg.p_workspace_;
}
return launch_and_time_kernel(
stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(kernel_args_dev),
arg.gemm_desc_kernel_arg_.size(),
grid_size_grp,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
};
float ave_time = Run(arg, arg.p_workspace_, stream_config); if(has_main_k_block_loop)
{
ave_time = launch_kernel(integral_constant<bool, true>{});
}
else
{
ave_time = launch_kernel(integral_constant<bool, false>{});
}
return ave_time; return ave_time;
} }
...@@ -967,6 +979,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -967,6 +979,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
return str.str(); return str.str();
} }
static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args)
{
arg.grouped_gemm_kernel_args_dev = kernel_args;
}
// polymorphic
void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override
{
return SetDeviceKernelArgs(*dynamic_cast<Argument*>(p_arg), kernel_args);
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{ {
return dynamic_cast<const Argument*>(p_arg)->group_count_ * return dynamic_cast<const Argument*>(p_arg)->group_count_ *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment