"vscode:/vscode.git/clone" did not exist on "ae99bcb9b647a872057cefb34407b143b81de576"
Commit c13776be authored by rocking's avatar rocking
Browse files

1. Allocate mean, var and count in the workspace set by SetWorkSpacePointer.

2. Add GetWorkSpaceSize to calculate the space size
parent 5215f11d
......@@ -175,6 +175,10 @@ int main()
throw std::runtime_error("wrong! this device_op instance does not support this problem");
}
size_t workspace_sz = device_op.GetWorkSpaceSize(&argument);
DeviceMem workspace_dev(workspace_sz);
device_op.SetWorkSpacePointer(&argument, workspace_dev.GetDeviceBuffer());
invoker.Run(argument, StreamConfig{nullptr, false});
if(do_verification)
......
......@@ -481,9 +481,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
p_welford_mean_grid_{nullptr},
p_welford_var_grid_{nullptr},
p_welford_count_grid_{nullptr},
p_workspace_mean_{nullptr},
p_workspace_var_{nullptr},
p_workspace_count_{nullptr},
p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
p_h_grid_{static_cast<HDataType*>(p_h_grid)},
......@@ -510,14 +510,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
mean_var_count_grid_desc_m_nblock_ =
DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
// TODO - GetWorkSpaceSize(), let user hipMalloc the memory
int gemm_welford_size = MRaw * gemm_nblock_;
hip_check_error(
hipMalloc(&p_welford_mean_grid_, sizeof(MeanDataType) * gemm_welford_size));
hip_check_error(
hipMalloc(&p_welford_var_grid_, sizeof(VarDataType) * gemm_welford_size));
hip_check_error(hipMalloc(&p_welford_count_grid_, sizeof(int32_t) * gemm_welford_size));
// populate pointer, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
......@@ -568,9 +560,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const BDataType* p_b_grid_;
typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
MeanDataType* p_welford_mean_grid_;
VarDataType* p_welford_var_grid_;
int32_t* p_welford_count_grid_;
void* p_workspace_mean_;
void* p_workspace_var_;
void* p_workspace_count_;
const GammaDataType* p_gamma_grid_;
const BetaDataType* p_beta_grid_;
HDataType* p_h_grid_;
......@@ -682,9 +674,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
static_cast<MeanDataType*>(arg.p_workspace_mean_),
static_cast<VarDataType*>(arg.p_workspace_var_),
static_cast<int32_t*>(arg.p_workspace_count_),
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
......@@ -703,27 +695,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
index_t numNormBlockTileIteration_N =
math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));
avg_time += launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
arg.p_welford_mean_grid_,
arg.p_welford_var_grid_,
arg.p_welford_count_grid_,
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N,
arg.epsilon_,
arg.h_element_op_);
avg_time +=
launch_and_time_kernel(stream_config,
kernel_welford_layernorm,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_e_grid_,
static_cast<const MeanDataType*>(arg.p_workspace_mean_),
static_cast<const VarDataType*>(arg.p_workspace_var_),
static_cast<const int32_t*>(arg.p_workspace_count_),
arg.p_gamma_grid_,
arg.p_beta_grid_,
arg.p_h_grid_,
arg.e_grid_desc_m_n_,
arg.h_grid_desc_m_n_,
arg.mean_var_count_grid_desc_m_nblock_,
arg.gamma_grid_desc_n_,
arg.beta_grid_desc_n_,
numMeanVarCountBlockTileIteration_N,
numNormBlockTileIteration_N,
arg.epsilon_,
arg.h_element_op_);
return avg_time;
};
......@@ -746,6 +739,54 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
}
};
size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
{
const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
size_t workspace_size = 0;
int gemm_welford_size = pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize();
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(MeanDataType) + 64;
// workspace for welford intermediate mean
workspace_size += gemm_welford_size * sizeof(VarDataType) + 64;
// workspace for welford intermediate count
workspace_size += gemm_welford_size * sizeof(int32_t) + 64;
return (workspace_size);
};
/// Stores the caller-provided workspace in the argument and splits it into
/// the three intermediate Welford buffers, laid out back to back as
/// [ mean | variance | count ].
///
/// The mean and variance regions are each rounded up to a multiple of 64
/// bytes, so the variance and count buffers start at offsets from
/// p_workspace that are multiples of 64.
///
/// @param pArg         argument object; it is cast to this operator's Argument.
/// @param p_workspace  start of the workspace buffer (presumably at least
///                     GetWorkSpaceSize(pArg) bytes - verify against caller).
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
{
    // NOTE(review): the dynamic_cast result is dereferenced without a null
    // check; passing an argument of a different type is not handled here.
    Argument* pArg_ = dynamic_cast<Argument*>(pArg);

    pArg_->p_workspace_ = p_workspace;

    const size_t gemm_welford_size =
        static_cast<size_t>(pArg_->mean_var_count_grid_desc_m_nblock_.GetElementSpaceSize());

    // Byte sizes are kept in size_t: element count * sizeof(T) is a size_t
    // product, and narrowing it into index_t before using it as a pointer
    // offset could truncate the value for large workspaces.
    const size_t alignment = 64;

    const size_t mean_space_sz =
        math::integer_least_multiple(gemm_welford_size * sizeof(MeanDataType), alignment);
    const size_t variance_space_sz =
        math::integer_least_multiple(gemm_welford_size * sizeof(VarDataType), alignment);

    // All offsets are byte offsets from the start of the workspace.
    char* const p_base = static_cast<char*>(p_workspace);

    // buffer used for intermediate welford mean
    pArg_->p_workspace_mean_ = p_base;

    // buffer used for intermediate welford variance
    pArg_->p_workspace_var_ = p_base + mean_space_sz;

    // buffer used for intermediate welford count
    pArg_->p_workspace_count_ = p_base + mean_space_sz + variance_space_sz;
};
static bool IsSupportedArgument(const Argument&)
{
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment