Commit 69e021cc authored by rocking's avatar rocking
Browse files

Fix bug of dx stride

parent 7d99da99
...@@ -239,8 +239,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType, ...@@ -239,8 +239,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
p_gamma_(p_gamma), p_gamma_(p_gamma),
p_mean_(p_mean), p_mean_(p_mean),
p_invStd_(p_invStd), p_invStd_(p_invStd),
p_dx_(p_dx), p_dx_(p_dx)
dxStrides_{dxStrides}
{ {
lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims); lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims); dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
...@@ -249,6 +248,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType, ...@@ -249,6 +248,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims); meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ = invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims); shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_); std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
...@@ -432,7 +432,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType, ...@@ -432,7 +432,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
reduceDims, reduceDims,
static_cast<const DYDataType*>(p_dy), static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x), static_cast<const XDataType*>(p_x),
static_cast<const XDataType*>(p_gamma), static_cast<const GammaDataType*>(p_gamma),
static_cast<const MeanInvStdDataType*>(p_mean), static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd), static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DXDataType*>(p_dx)); static_cast<DXDataType*>(p_dx));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment