Commit c9efd987 authored by rocking's avatar rocking
Browse files

Move descriptor and sweeponce to argument for quick debugging

parent c54b7bc9
...@@ -250,6 +250,18 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -250,6 +250,18 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
M_BlockTileSize * blkGroupSize_; M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
y_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
} }
AccDataType epsilon_; AccDataType epsilon_;
...@@ -270,25 +282,20 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -270,25 +282,20 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
int blkGroupSize_; int blkGroupSize_;
int numBlockTileIteration_; int numBlockTileIteration_;
size_t gridSize_; size_t gridSize_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K beta_grid_desc_m_k_;
GridDesc_M_K y_grid_desc_m_k_;
bool isSweeponce_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
const auto x_grid_desc_m_k = MakeSrc2dDescriptor( const auto kernel_main = arg.isSweeponce_
arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_); ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
const auto gamma_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
const auto beta_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
bool sweep_once =
x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
const auto kernel_main = sweep_once ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
XDataType, XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
...@@ -311,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType, ...@@ -311,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
dim3(arg.gridSize_), dim3(arg.gridSize_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
x_grid_desc_m_k, arg.x_grid_desc_m_k_,
gamma_grid_desc_m_k, arg.gamma_grid_desc_m_k_,
beta_grid_desc_m_k, arg.beta_grid_desc_m_k_,
y_grid_desc_m_k, arg.y_grid_desc_m_k_,
arg.numBlockTileIteration_, arg.numBlockTileIteration_,
arg.epsilon_, arg.epsilon_,
arg.p_x_, arg.p_x_,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment