Commit c13776be authored by rocking's avatar rocking
Browse files

1. Allocate mean, var and count into by SetWorkSpacePointer.

2. Add GetWorkSpaceSize to calculate the space size
parent 5215f11d
......@@ -175,6 +175,10 @@ int main()
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
size_t workspace_sz = device_op.GetWorkSpaceSize(&argument);
DeviceMem workspace_dev(workspace_sz);
device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
if(do_verification)
......
......@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_welford_mean_grid_{nullptr},
p_welford_var_grid_{nullptr},
p_welford_count_grid_{nullptr},
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
......@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
......@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
MeanDataType* p_welford_mean_grid_;
VarDataType* p_welford_var_grid_;
int32_t* p_welford_count_grid_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_;
......@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
static_cast<MeanDataType*>(arg.p_workspace_mean_),
static_cast<VarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
......@@ -703,15 +695,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
avg_time += launch_and_time_kernel(stream_config,
avg_time +=
launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
static_cast<const MeanDataType*>(arg.p_workspace_mean_),
static_cast<const VarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
......@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64;
// workspace for welford intermediate count
workspace_size += gemm_welford_size * sizeof(int32_t) + 64;
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
pArg_->p_workspace_ = p_workspace;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// int gemm_welford_size = MRaw * pArg->gemm_nblock_;
// setup buffer used for intermediate welford mean
pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
index_t mean_space_sz = gemm_welford_size * sizeof(MeanDataType);
mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
// setup buffer used for intermediate welford varirance
pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
index_t variance_space_sz = gemm_welford_size * sizeof(VarDataType);
variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
// setup buffer used for intermediate welford count
pArg_->p_workspace_count_ =
reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
};
static bool IsSupportedArgument(const Argument&)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment