Unverified Commit b6eaf3eb authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Pass gemm_descs for grouped gemm via __constant__ buff (#232)

* moved gemm_descs_args into const buff

* use CK_CONSTANT_ADDRESS_SPACE instead of global constant

* clean

* moved hipMemAlloc outside of deviceOp

* add SetWorkSpacePointer

* fix ignore
parent 7b1e2c37
...@@ -78,7 +78,7 @@ int main(int argc, char* argv[]) ...@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 4; int group_count = rand() % 16 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
...@@ -189,12 +189,17 @@ int main(int argc, char* argv[]) ...@@ -189,12 +189,17 @@ int main(int argc, char* argv[])
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
// do GEMM
auto argument = auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
throw std::runtime_error( throw std::runtime_error(
......
...@@ -42,6 +42,8 @@ struct BaseOperator ...@@ -42,6 +42,8 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument*, void*) const {}
virtual ~BaseOperator() {} virtual ~BaseOperator() {}
}; };
......
...@@ -24,57 +24,33 @@ template <typename GridwiseGemm, ...@@ -24,57 +24,33 @@ template <typename GridwiseGemm,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
bool HasMainKBlockLoop, bool HasMainKBlockLoop>
index_t MaxGroupCount>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdlops_v2r3( kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_descs, const index_t group_count,
const index_t group_count, const AElementwiseOperation a_element_op,
const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op,
const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
const CElementwiseOperation c_element_op)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
#if 1 const auto gemm_desc_ptr =
static_for<0, MaxGroupCount, 1>{}([&](auto i) { reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ &&
i < group_count)
{
auto group_id = i;
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_descs[group_id].a_ptr,
gemm_descs[group_id].b_ptr,
gemm_descs[group_id].c_ptr,
p_shared,
gemm_descs[group_id].a_grid_desc_k0_m_k1_,
gemm_descs[group_id].b_grid_desc_k0_n_k1_,
gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
a_element_op,
b_element_op,
c_element_op,
gemm_descs[group_id].grouped_gemm_block_2_ctile_map_);
}
});
#else
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_descs);
index_t group_id = 0; index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) { for(index_t i = 0; i < group_count; i++)
group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd && {
i < group_count) group_id =
? i (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_)
: group_id; ? i
}); : group_id;
}
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr, gemm_desc_ptr[group_id].a_ptr,
...@@ -87,11 +63,9 @@ __global__ void ...@@ -87,11 +63,9 @@ __global__ void
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_, gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_);
block_id_grp);
#endif
#else #else
ignore = gemm_descs; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
ignore = a_element_op; ignore = a_element_op;
ignore = b_element_op; ignore = b_element_op;
...@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl ...@@ -388,6 +362,8 @@ struct DeviceGroupedGemmXdl
{ {
grid_size_ = 0; grid_size_ = 0;
gemm_descs_args_workspace_ = nullptr;
group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size()); group_count_ = ck::type_convert<ck::index_t>(gemm_shapes.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) && if(!(group_count_ == ck::type_convert<ck::index_t>(p_a.size()) &&
...@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl ...@@ -461,6 +437,8 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
void* gemm_descs_args_workspace_;
index_t grid_size_; index_t grid_size_;
}; };
...@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl ...@@ -471,49 +449,49 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
StaticallyIndexedArray<GemmDescKernelArg, MaxGroupCount> gemm_desc_kernel_args;
bool has_main_k_block_loop = true; bool has_main_k_block_loop = true;
static_for<0, MaxGroupCount, 1>{}([&](auto i) { for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
if(i < arg.gemm_desc_kernel_arg_.size()) {
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
std::cout << ", arg.c_grid_desc_m_n_{ "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_,
arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_))
{ {
gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{"
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}";
std::cout << ", arg.b_grid_desc_k0_n_k1_{"
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}";
std::cout << ", arg.c_grid_desc_m_n_{ "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
if(!GridwiseGemm::CheckValidity(
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_,
gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_,
gemm_desc_kernel_args[i].c_grid_desc_m_n_,
gemm_desc_kernel_args[i].grouped_gemm_block_2_ctile_map_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
} }
});
const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) *
arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
{
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
}
}
hipGetErrorString(
hipMemcpy(arg.gemm_descs_args_workspace_,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg),
hipMemcpyHostToDevice));
float ave_time = 0; float ave_time = 0;
...@@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl ...@@ -523,23 +501,23 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm, kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<GemmDescKernelArg>, GemmDescKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
true, true>;
MaxGroupCount>;
ave_time = launch_and_time_kernel(
ave_time = launch_and_time_kernel(stream_config, stream_config,
kernel, kernel,
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
gemm_desc_kernel_args, cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.c_element_op_);
} }
else else
{ {
...@@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl ...@@ -547,23 +525,23 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm, kernel_grouped_gemm_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
remove_reference_t<GemmDescKernelArg>, GemmDescKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
false, false>;
MaxGroupCount>;
ave_time = launch_and_time_kernel(
ave_time = launch_and_time_kernel(stream_config, stream_config,
kernel, kernel,
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
gemm_desc_kernel_args, cast_pointer_to_constant_address_space(arg.gemm_descs_args_workspace_),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.c_element_op_); arg.c_element_op_);
} }
return ave_time; return ave_time;
...@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl ...@@ -652,6 +630,16 @@ struct DeviceGroupedGemmXdl
return str.str(); return str.str();
} }
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmDescKernelArg);
}
void SetWorkSpacePointer(BaseArgument* p_arg, void* workspace_ptr) const override
{
dynamic_cast<Argument*>(p_arg)->gemm_descs_args_workspace_ = workspace_ptr;
}
}; };
} // namespace device } // namespace device
......
...@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -141,10 +141,15 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
auto c_element_op = PassThrough{}; auto c_element_op = PassThrough{};
// do GEMM // do GEMM
auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer();
auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( auto argument_ptr = groupedGemmPtr->MakeArgumentPointer(
p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get()));
groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
invoker_ptr->Run(argument_ptr.get()); invoker_ptr->Run(argument_ptr.get());
for(std::size_t i = 0; i < gemm_shapes.size(); i++) for(std::size_t i = 0; i < gemm_shapes.size(); i++)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment