Unverified Commit f46a6ffa authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Fix the fp8 gemm for large tensors on MI300. (#1011)



* Fix the fp8 conversion

* Try clipping value before conversion

* Fix return

* Simplify with a const

* reduce the gemm input tensor values to reduce round-off error

* replace if-else with lambda

* fix syntax

---------
Co-authored-by: default avatarRostyslav Geyyer <rosty.geyyer@amd.com>
parent 6fe0bc7e
......@@ -100,6 +100,8 @@ template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
......
......@@ -75,8 +75,8 @@ int profile_gemm_impl(int do_verification,
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 0.1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.05, 0.05});
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment