Commit 646b8e5c authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers

parent 405fdaec
......@@ -40,8 +40,8 @@ using BF8 = ck::bf8_t;
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
int init_method = 2;
bool time_kernel = false;
};
#define DefaultConvParams \
......
......@@ -10,6 +10,7 @@ using AccDataType = FP32;
using CShuffleDataType = FP16;
using DsDataType = ck::Tuple<>;
using InDataType = FP16;
using VerifyDataType = FP16; // is used for selection of check tolerances
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
......@@ -12,6 +12,7 @@ using DsDataType = ck::Tuple<>;
using InDataType = FP16;
using AComputeType = BF8;
using BComputeType = FP8;
using VerifyDataType = BF8; // is used for selection of check tolerances
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1;
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1;
}
else
{
return 1e-3;
}
}
bool run_conv_bwd_data(const ExecutionConfig& config,
const ck::utils::conv::ConvParam& conv_params,
......@@ -27,7 +108,7 @@ bool run_conv_bwd_data(const ExecutionConfig& config,
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{0.0, 1.0});
out.GenerateTensorValue(GeneratorTensor_3<OutDataType>{-0.5, 0.5});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
......@@ -138,7 +219,11 @@ bool run_conv_bwd_data(const ExecutionConfig& config,
in_device_buf.FromDevice(in_device.mData.data());
return ck::utils::check_err(in_device.mData, in_host.mData);
return ck::utils::check_err(in_device.mData,
in_host.mData,
"Error: Incorrect results!",
get_rtol<VerifyDataType>(),
get_atol<VerifyDataType>());
}
return true;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment