Commit 7e2f7c95 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Enable build of example_gemm_xdl_fp8_bf8 test.

parent 043709c2
...@@ -291,6 +291,9 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) ...@@ -291,6 +291,9 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
} // namespace fp8_impl } // namespace fp8_impl
template <typename T, index_t N>
struct non_native_vector_base;
struct f8_ocp_t struct f8_ocp_t
{ {
using data_type = fp8_storage_t; using data_type = fp8_storage_t;
...@@ -336,8 +339,51 @@ struct f8_ocp_t ...@@ -336,8 +339,51 @@ struct f8_ocp_t
} }
}; };
template <typename T, index_t N> struct bf8_ocp_t
struct non_native_vector_base; {
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret = CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
#else
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
};
template <index_t N> template <index_t N>
struct non_native_vector_base<f8_ocp_t, N> struct non_native_vector_base<f8_ocp_t, N>
...@@ -383,50 +429,20 @@ struct non_native_vector_base<f8_ocp_t, 2> ...@@ -383,50 +429,20 @@ struct non_native_vector_base<f8_ocp_t, 2>
} }
}; };
struct bf8_ocp_t template <index_t N>
struct non_native_vector_base<bf8_ocp_t, N>
{ {
using data_type = fp8_storage_t; using data_t = bf8_ocp_t::data_type;
data_type data; using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
using type = non_native_vector_base<bf8_ocp_t, N>;
static constexpr ck_saturation_t default_saturation = CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret = CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8 data_v d; // storage vector
__host__ __device__ explicit operator float() const
#else __host__ __device__ non_native_vector_base() = default;
__host__ explicit operator float() const __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
#endif __host__ __device__ non_native_vector_base(data_v v) : d{v} {}
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8 __host__ __device__ operator data_v() const { return d; }
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
}; };
namespace fp8_impl { namespace fp8_impl {
......
...@@ -1036,10 +1036,6 @@ struct non_native_vector_base ...@@ -1036,10 +1036,6 @@ struct non_native_vector_base
template <typename T, index_t N> template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>; struct scalar_type<non_native_vector_base<T, N>>;
// {
// using type = T;
// static constexpr index_t vector_size = N;
// };
template <index_t N> template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>> struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
...@@ -1049,6 +1045,14 @@ struct scalar_type<non_native_vector_base<f8_ocp_t, N>> ...@@ -1049,6 +1045,14 @@ struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
static constexpr index_t vector_size = N; static constexpr index_t vector_size = N;
}; };
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation // non-native vector_type implementation
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment