Commit 7d99da99 authored by rocking's avatar rocking
Browse files

Fix bug of KRaw_ error

parent b7aa49a3
...@@ -250,7 +250,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType, ...@@ -250,7 +250,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
invStdStrides_ = invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims); shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths); std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
...@@ -265,9 +265,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType, ...@@ -265,9 +265,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);
// TODO - sweep once for small k
isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
// isSweeponce_ = false;
} }
const DYDataType* p_dy_; const DYDataType* p_dy_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment