Commit 7fb23099 authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Add checking for reduceDims in reference_batchnorm_backward

parent 9f490d1f
......@@ -59,7 +59,6 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
ignore = xStrides;
ignore = dyStrides;
ignore = dxStrides;
ignore = reduceDims;
ignore = bnScaleStrides;
ignore = bnBiasStrides;
ignore = bnMeanVarStrides;
......@@ -68,6 +67,9 @@ struct ReferenceBatchNormBwd_Input_N_H_W_C_Output_C
bnScaleBiasMeanVarLengths[0] != xyLengths[3])
throw std::runtime_error("Invalid tensor dimensions!");
if(reduceDims[0] != 0 || reduceDims[1] != 1 || reduceDims[2] != 2)
throw std::runtime_error("Invalid reduce dimensions!");
n_ = xyLengths[0];
h_ = xyLengths[1];
w_ = xyLengths[2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment