Commit ff6bbf40 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Verify 04_gemm_add_add_fastgelu on floating point numbers

parent 52cd7ade
...@@ -57,7 +57,7 @@ struct ProblemSize final ...@@ -57,7 +57,7 @@ struct ProblemSize final
struct ExecutionConfig final struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
}; };
......
...@@ -7,7 +7,7 @@ using ADataType = BF16; ...@@ -7,7 +7,7 @@ using ADataType = BF16;
using BDataType = BF16; using BDataType = BF16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F32; using CShuffleDataType = F32;
using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification using CDataType = F32; // C matrix doesn't exist in GPU memory, this is used for host verification
using D0DataType = BF16; using D0DataType = BF16;
using D1DataType = BF16; using D1DataType = BF16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
......
#pragma once #pragma once
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 2e-1;
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 2e-1;
}
else
{
return 1e-3;
}
}
bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionConfig& config) bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionConfig& config)
{ {
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
...@@ -150,7 +232,11 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC ...@@ -150,7 +232,11 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC
return ck::utils::check_err(e_m_n_device_result_converted, e_m_n_host_result); return ck::utils::check_err(e_m_n_device_result_converted, e_m_n_host_result);
#else #else
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); return ck::utils::check_err(e_m_n_device_result,
e_m_n_host_result,
"Error: Incorrect results!",
get_rtol<EDataType>(),
get_atol<EDataType>());
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment