Commit 325d0909 authored by rocking's avatar rocking
Browse files

Refine device normalization. Remove the block group size.

This variable is for splitK. I will create new class for normalization splitK
parent 6b4a1e5e
...@@ -20,6 +20,10 @@ namespace tensor_operation { ...@@ -20,6 +20,10 @@ namespace tensor_operation {
namespace device { namespace device {
// Y = Normalization(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
// M: Invarient length
// K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W
template <typename XDataType, template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
...@@ -68,7 +72,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -68,7 +72,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths, static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
const std::vector<index_t>& inStrides, const std::vector<index_t>& inStrides,
int blkGroupSize,
int numBlockTileIteration) int numBlockTileIteration)
{ {
constexpr index_t NumInvariantDim = Rank - NumReduceDim; constexpr index_t NumInvariantDim = Rank - NumReduceDim;
...@@ -117,10 +120,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -117,10 +120,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
const auto inPad_M = const auto inPad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto in_grid_desc_m_k_padded = transform_tensor_descriptor( auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
in_grid_desc_m_k, in_grid_desc_m_k,
...@@ -132,7 +134,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -132,7 +134,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
return (in_grid_desc_m_k_padded); return (in_grid_desc_m_k_padded);
}; };
using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -162,26 +164,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -162,26 +164,22 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims); gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims); betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
long_index_t invariant_total_length; long_index_t invariant_length;
long_index_t reduce_total_length; long_index_t reduce_length;
std::tie(invariant_total_length, reduce_total_length) = std::tie(invariant_length, reduce_length) =
get_2d_lengths<Rank, NumReduceDim>(Lengths_); get_2d_lengths<Rank, NumReduceDim>(Lengths_);
blkGroupSize_ = 1; numBlockTileIteration_ = math::integer_divide_ceil(reduce_length, K_BlockTileSize);
numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / gridSize_ = math::integer_divide_ceil(invariant_length, M_BlockTileSize);
M_BlockTileSize * blkGroupSize_;
x_grid_desc_m_k_ = x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_);
MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
gamma_grid_desc_m_k_ = gamma_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_);
beta_grid_desc_m_k_ = beta_grid_desc_m_k_ =
MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_);
y_grid_desc_m_k_ = y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_);
MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
isSweeponce_ = isSweeponce_ =
x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
...@@ -202,7 +200,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType, ...@@ -202,7 +200,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
YElementwiseOperation y_elementwise_op_; YElementwiseOperation y_elementwise_op_;
int blkGroupSize_;
int numBlockTileIteration_; int numBlockTileIteration_;
size_t gridSize_; size_t gridSize_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment