Commit c9efd987 authored by rocking
Browse files

Move the grid descriptors and the sweep-once flag into Argument for quick debugging

parent c54b7bc9
...@@ -250,6 +250,18 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -250,6 +250,18 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_; M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
} }
AccDataType epsilon_; AccDataType epsilon_;
...@@ -270,40 +282,35 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -270,40 +282,35 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
int blkGroupSize_; int blkGroupSize_;
int numBlockTileIteration_; int numBlockTileIteration_;
size_t gridSize_; size_t gridSize_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_;
bool isSweeponce_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto x_grid_desc_m_k = MakeSrc2dDescriptor( const auto kernel_main = arg.isSweeponce_
arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
const auto gamma_grid_desc_m_k = MakeSrc2dDescriptor( XDataType,
arg.Lengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); GammaDataType,
const auto beta_grid_desc_m_k = MakeSrc2dDescriptor( BetaDataType,
arg.Lengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); YDataType,
const auto y_grid_desc_m_k = MakeSrc2dDescriptor( AccDataType,
arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); AccElementwiseOperation,
GridDesc_M_K>
bool sweep_once = : kernel_layernorm<GridwiseReduceLayernormGeneric,
x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; XDataType,
GammaDataType,
const auto kernel_main = sweep_once ? kernel_layernorm<GridwiseReduceLayernormSweepOnce, BetaDataType,
XDataType, YDataType,
GammaDataType, AccDataType,
BetaDataType, AccElementwiseOperation,
YDataType, GridDesc_M_K>;
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>
: kernel_layernorm<GridwiseReduceLayernormGeneric,
XDataType,
GammaDataType,
BetaDataType,
YDataType,
AccDataType,
AccElementwiseOperation,
GridDesc_M_K>;
float avg_time = 0; float avg_time = 0;
avg_time += launch_and_time_kernel(stream_config, avg_time += launch_and_time_kernel(stream_config,
...@@ -311,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -311,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
x_grid_desc_m_k, arg.x_grid_desc_m_k_,
gamma_grid_desc_m_k, arg.gamma_grid_desc_m_k_,
beta_grid_desc_m_k, arg.beta_grid_desc_m_k_,
y_grid_desc_m_k, arg.y_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.p_x_, arg.p_x_,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment