Commit 69e021cc authored by rocking's avatar rocking
Browse files

Fix bug of dx stride

parent 7d99da99
......@@ -239,8 +239,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
p_gamma_(p_gamma),
p_mean_(p_mean),
p_invStd_(p_invStd),
p_dx_(p_dx),
dxStrides_{dxStrides}
p_dx_(p_dx)
{
lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
......@@ -249,6 +248,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
......@@ -432,7 +432,7 @@ struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
reduceDims,
static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x),
static_cast<const XDataType*>(p_gamma),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DXDataType*>(p_dx));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment