Commit 9f490d1f authored by Qianfeng Zhang's avatar Qianfeng Zhang
Browse files

Renaming in backward example again

parent 59613285
......@@ -142,14 +142,14 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
Tensor<InOutDataType> dx_ref(inOutLengths);
Tensor<InOutDataType> dx(inOutLengths);
Tensor<AccDataType> bnScaleDiff(scaleBiasMeanVarLengths);
Tensor<AccDataType> bnBiasDiff(scaleBiasMeanVarLengths);
Tensor<AccDataType> dscale(scaleBiasMeanVarLengths);
Tensor<AccDataType> dbias(scaleBiasMeanVarLengths);
Tensor<AccDataType> bnScaleDiff_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> bnBiasDiff_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> dscale_ref(scaleBiasMeanVarLengths);
Tensor<AccDataType> dbias_ref(scaleBiasMeanVarLengths);
auto inOutStrides = dy.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = bnScaleDiff.mDesc.GetStrides();
auto scaleBiasMeanVarStrides = dscale.mDesc.GetStrides();
std::size_t num_thread = std::thread::hardware_concurrency();
......@@ -226,8 +226,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
// output data of the batchnorm backward algorithm
DeviceMem dx_dev(sizeof(InOutDataType) * dx.mDesc.GetElementSpaceSize());
DeviceMem bnScaleDiff_dev(sizeof(AccDataType) * bnScaleDiff.mDesc.GetElementSpaceSize());
DeviceMem bnBiasDiff_dev(sizeof(AccDataType) * bnBiasDiff.mDesc.GetElementSpaceSize());
DeviceMem dscale_dev(sizeof(AccDataType) * dscale.mDesc.GetElementSpaceSize());
DeviceMem dbias_dev(sizeof(AccDataType) * dbias.mDesc.GetElementSpaceSize());
x_dev.ToDevice(x.mData.data());
dy_dev.ToDevice(dy.mData.data());
......@@ -300,8 +300,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
epsilon,
PassThroughOp{},
dx_dev.GetDeviceBuffer(),
bnScaleDiff_dev.GetDeviceBuffer(),
bnBiasDiff_dev.GetDeviceBuffer());
dscale_dev.GetDeviceBuffer(),
dbias_dev.GetDeviceBuffer());
if(!batchnorm_bwd.IsSupportedArgument(argument_ptr.get()))
{
......@@ -377,8 +377,8 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
epsilon,
PassThroughOp{},
dx_ref.mData.data(),
bnScaleDiff_ref.mData.data(),
bnBiasDiff_ref.mData.data());
dscale_ref.mData.data(),
dbias_ref.mData.data());
if(!batchNormBwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
{
......@@ -393,13 +393,12 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
dx_dev.FromDevice(dx.mData.data());
bnScaleDiff_dev.FromDevice(bnScaleDiff.data());
bnBiasDiff_dev.FromDevice(bnBiasDiff.data());
pass = pass && ck::utils::check_err(
bnBiasDiff.mData, bnBiasDiff_ref.mData, "BiasDiff result:", 1e-5, 1e-5);
dscale_dev.FromDevice(dscale.data());
dbias_dev.FromDevice(dbias.data());
pass =
pass && ck::utils::check_err(
bnScaleDiff.mData, bnScaleDiff_ref.mData, "ScaleDiff result:", 1e-5, 2e-4);
pass && ck::utils::check_err(dbias.mData, dbias_ref.mData, "dBias result:", 1e-5, 1e-5);
pass = pass &&
ck::utils::check_err(dscale.mData, dscale_ref.mData, "dScale result:", 1e-5, 2e-4);
pass = pass && ck::utils::check_err(dx.mData, dx_ref.mData, "dx result:");
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment