Commit 6fae79b6 authored by rocking's avatar rocking
Browse files

Check vector size

parent ee6cd44d
...@@ -101,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -101,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"); "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static_assert( static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"); "check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
......
...@@ -35,7 +35,7 @@ template <typename DYDataType, ...@@ -35,7 +35,7 @@ template <typename DYDataType,
index_t DBetaDstVectorSize> index_t DBetaDstVectorSize>
struct GridwiseNormalizationBwdGammaBeta_mk_to_k struct GridwiseNormalizationBwdGammaBeta_mk_to_k
{ {
// if we just check ThreadSliceSize & VectorSize == 0, the performance may be poor // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce)
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
(DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
...@@ -44,6 +44,14 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k ...@@ -44,6 +44,14 @@ struct GridwiseNormalizationBwdGammaBeta_mk_to_k
(XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"); "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
"Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");
static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize,
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>; using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
using DYThreadBufferDimAccessOrder = using DYThreadBufferDimAccessOrder =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment