Commit c13776be authored by rocking's avatar rocking
Browse files

1. Place the mean, var and count buffers in the workspace set via SetWorkSpacePointer.

2. Add GetWorkSpaceSize to calculate the space size
parent 5215f11d
...@@ -175,6 +175,10 @@ int main() ...@@ -175,6 +175,10 @@ int main()
throw std::runtime_error("wrong! this device_op instance does not support this problem"); throw std::runtime_error("wrong! this device_op instance does not support this problem");
} }
size_t workspace_sz = device_op.GetWorkSpaceSize(&argument);
DeviceMem workspace_dev(workspace_sz);
device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false});
if(do_verification) if(do_verification)
......
...@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_welford_mean_grid_{nullptr}, p_workspace_mean_{nullptr},
p_welford_var_grid_{nullptr}, p_workspace_var_{nullptr},
p_welford_count_grid_{nullptr}, p_workspace_count_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)}, p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)}, p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)}, p_h_grid_{static_cast<HDataType*>(p_h_grid)},
...@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_ = mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_); DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
...@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_; EDataType* p_e_grid_;
MeanDataType* p_welford_mean_grid_; void* p_workspace_mean_;
VarDataType* p_welford_var_grid_; void* p_workspace_var_;
int32_t* p_welford_count_grid_; void* p_workspace_count_;
const GammaDataType* p_gamma_grid_; const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_; const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_; HDataType* p_h_grid_;
...@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_, arg.p_b_grid_,
arg.p_ds_grid_, arg.p_ds_grid_,
arg.p_e_grid_, arg.p_e_grid_,
arg.p_welford_mean_grid_, static_cast<MeanDataType*>(arg.p_workspace_mean_),
arg.p_welford_var_grid_, static_cast<VarDataType*>(arg.p_workspace_var_),
arg.p_welford_count_grid_, static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
arg.cde_element_op_, arg.cde_element_op_,
...@@ -703,15 +695,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -703,15 +695,16 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t numNormBlockTileIteration_N = index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1)); math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
avg_time += launch_and_time_kernel(stream_config, avg_time +=
launch_and_time_kernel(stream_config,
kernel_welford_layernorm, kernel_welford_layernorm,
dim3(grid_size), dim3(grid_size),
dim3(BlockSize), dim3(BlockSize),
0, 0,
arg.p_e_grid_, arg.p_e_grid_,
arg.p_welford_mean_grid_, static_cast<const MeanDataType*>(arg.p_workspace_mean_),
arg.p_welford_var_grid_, static_cast<const VarDataType*>(arg.p_workspace_var_),
arg.p_welford_count_grid_, static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_, arg.p_gamma_grid_,
arg.p_beta_grid_, arg.p_beta_grid_,
arg.p_h_grid_, arg.p_h_grid_,
...@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator ...@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
} }
}; };
/// Returns the number of bytes of workspace needed to hold the three
/// intermediate Welford buffers (mean, variance, count).
///
/// Each buffer holds `mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize()`
/// elements. Every buffer is budgeted with 64 extra bytes, which is always at
/// least as much as SetWorkSpacePointer() consumes when it rounds each
/// sub-buffer size up to a multiple of 64 bytes.
///
/// @param pArg  Argument previously created for this device operator.
/// @return      Required workspace size in bytes.
/// @throws std::runtime_error if `pArg` is not an `Argument` of this operator.
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
    const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
    if(pArg_ == nullptr)
    {
        // A failed downcast would otherwise be dereferenced below.
        throw std::runtime_error(
            "wrong! the argument is not an Argument of DeviceGemmMultipleDLayernorm_Xdl_CShuffle");
    }

    // Per-buffer slack that covers the 64-byte rounding done in SetWorkSpacePointer().
    constexpr size_t buffer_padding = 64;

    const size_t gemm_welford_size =
        static_cast<size_t>(pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize());

    size_t workspace_size = 0;

    // workspace for welford intermediate mean
    workspace_size += gemm_welford_size * sizeof(MeanDataType) + buffer_padding;

    // workspace for welford intermediate variance
    workspace_size += gemm_welford_size * sizeof(VarDataType) + buffer_padding;

    // workspace for welford intermediate count
    workspace_size += gemm_welford_size * sizeof(int32_t) + buffer_padding;

    return workspace_size;
}
/// Stores the workspace pointer in the argument and carves it into the three
/// intermediate Welford buffers, laid out back to back:
///
///   [ mean | variance | count ]
///
/// The mean and variance regions are each rounded up to a multiple of 64
/// bytes, so the variance and count buffers start at offsets that are
/// multiples of 64 bytes from `p_workspace`.
///
/// @param pArg         Argument previously created for this device operator.
/// @param p_workspace  Buffer of at least GetWorkSpaceSize(pArg) bytes.
/// @throws std::runtime_error if `pArg` is not an `Argument` of this operator.
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
    Argument* pArg_ = dynamic_cast<Argument*>(pArg);
    if(pArg_ == nullptr)
    {
        // A failed downcast would otherwise be dereferenced below.
        throw std::runtime_error(
            "wrong! the argument is not an Argument of DeviceGemmMultipleDLayernorm_Xdl_CShuffle");
    }

    pArg_->p_workspace_ = p_workspace;

    constexpr size_t buffer_alignment = 64;

    // Keep all byte arithmetic in size_t: narrowing the byte counts to index_t
    // could overflow for very large workspaces and yield wrong offsets.
    const size_t gemm_welford_size =
        static_cast<size_t>(pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize());

    // setup buffer used for intermediate welford mean
    char* const p_mean = static_cast<char*>(p_workspace);
    const size_t mean_space_sz =
        math::integer_least_multiple(gemm_welford_size * sizeof(MeanDataType), buffer_alignment);

    // setup buffer used for intermediate welford variance
    char* const p_var = p_mean + mean_space_sz;
    const size_t variance_space_sz =
        math::integer_least_multiple(gemm_welford_size * sizeof(VarDataType), buffer_alignment);

    // setup buffer used for intermediate welford count
    char* const p_count = p_var + variance_space_sz;

    pArg_->p_workspace_mean_  = p_mean;
    pArg_->p_workspace_var_   = p_var;
    pArg_->p_workspace_count_ = p_count;
}
static bool IsSupportedArgument(const Argument&) static bool IsSupportedArgument(const Argument&)
{ {
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment