Unverified Commit d0c7b451 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Clip fp8 to +/-240 on all targets. (#1172)

* clip fp8 to +/-240 on all targets

* if inputs to fp8 conversion are +/-inf, they remain unaltered

* increase tolerance for test_elementwise_layernorm to prevent false errors

* change the input values for gemm examples to floats

* reduce gemm example float input values to prevent errors

* increase the tolerance for gemm examples
parent d9095997
......@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
int init_method = 2;
bool time_kernel = false;
};
......
......@@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break;
default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
}
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......@@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
return ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1);
#endif
}
......
......@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
......@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
......@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
union
......@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
......@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union
{
float fval;
......
......@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
y_dev.FromDevice(y.mData.data());
bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3);
if(do_log)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment