"docs/vscode:/vscode.git/clone" did not exist on "5202456b4c76fee2a2e80184c8d5112dd26911d0"
Unverified Commit d0c7b451 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Clip fp8 to +/-240 on all targets. (#1172)

* clip fp8 to +/-240 on all targets

* if inputs to fp8 conversion are +/-inf, they remain unaltered

* increase tolerance for test_elementwise_layernorm to prevent false errors

* change the input values for gemm examples to floats

* reduce gemm example float input values to prevent errors

* increase the tolerance for gemm examples
parent d9095997
...@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final ...@@ -49,7 +49,7 @@ struct ProblemSizeStreamK final
struct ExecutionConfig final struct ExecutionConfig final
{ {
bool do_verification = true; bool do_verification = true;
int init_method = 1; int init_method = 2;
bool time_kernel = false; bool time_kernel = false;
}; };
......
...@@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n); ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
break; break;
default: default:
ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution<ADataType>{-0.1f, 0.1f}(a_m_k);
ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n); ck::utils::FillUniformDistribution<BDataType>{-0.1f, 0.1f}(b_k_n);
} }
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
...@@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
#else #else
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); return ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1);
#endif #endif
} }
......
...@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
float max_fp8 = 240.0f; float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union union
{ {
float fval; float fval;
...@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
...@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{ {
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__) #if defined(__gfx94__)
union union
...@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
...@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); ...@@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
template <> template <>
inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
{ {
#if defined(__gfx94__)
float max_fp8 = 240.0f; float max_fp8 = 240.0f;
if(!std::isinf(x))
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
#if defined(__gfx94__)
union union
{ {
float fval; float fval;
......
...@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, ...@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
y_dev.FromDevice(y.mData.data()); y_dev.FromDevice(y.mData.data());
bool pass = bool pass =
ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3);
if(do_log) if(do_log)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment