"...composable_kernel.git" did not exist on "d8f1458f448f4509d950ef04adc15eda622a9c5d"
Commit 7e2f7c95 authored by Andriy Roshchenko
Browse files

Enable build of example_gemm_xdl_fp8_bf8 test.

parent 043709c2
......@@ -291,6 +291,9 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
} // namespace fp8_impl
template <typename T, index_t N>
struct non_native_vector_base;
struct f8_ocp_t
{
using data_type = fp8_storage_t;
......@@ -336,8 +339,51 @@ struct f8_ocp_t
}
};
template <typename T, index_t N>
struct non_native_vector_base;
// 8-bit floating-point value using the CK_E5M2_OCP interpretation
// (5 exponent bits, 2 mantissa bits), kept as a raw fp8_storage_t.
struct bf8_ocp_t
{
    using data_type = fp8_storage_t;

    // Raw encoded bits of the value.
    data_type data;

    // Defaults that accompany this type: saturation mode and bit interpretation.
    static constexpr ck_saturation_t default_saturation       = CK_SATFINITE;
    static constexpr ck_fp8_interpretation_t default_interpret = CK_E5M2_OCP;

    // Field widths of the encoding, in bits.
    static constexpr unsigned int we = 5; // exponent width
    static constexpr unsigned int wm = 2; // mantissa width

    // Equal only when the raw bits match and the value is not NaN,
    // so a NaN never compares equal, not even to itself.
    __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
    {
        return (other.data == data) && !fp8_impl::ocp_bf8_is_nan(data);
    }

    // Conversion to float.
    // Callable from device code only when CK_USE_OCP_FP8 is enabled; host-only otherwise.
    __host__
#if CK_USE_OCP_FP8
        __device__
#endif
        explicit
        operator float() const
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        // These targets go through the dedicated fp8 -> f32 cast helper.
        return fp8_impl::cast_to_f32_from_f8<default_interpret>(data);
#else
        // Generic path. XXX: clip==false here has to stay in sync with operator _Float16.
        return fp8_impl::cast_from_f8<float, wm, we, false>(data);
#endif
    }

    // Conversion to _Float16.
    // Callable from device code only when CK_USE_OCP_FP8 is enabled; host-only otherwise.
    __host__
#if CK_USE_OCP_FP8
        __device__
#endif
        explicit
        operator _Float16() const
    {
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
        // Take the result of the fp8 -> f32 cast helper, then narrow it to _Float16.
        const auto widened = fp8_impl::cast_to_f32_from_f8<default_interpret>(data);
        return static_cast<_Float16>(widened);
#else
        // Generic path. XXX: clip==false here has to stay in sync with operator float.
        return fp8_impl::cast_from_f8<_Float16, wm, we, false>(data);
#endif
    }
};
template <index_t N>
struct non_native_vector_base<f8_ocp_t, N>
......@@ -383,50 +429,20 @@ struct non_native_vector_base<f8_ocp_t, 2>
}
};
struct bf8_ocp_t
template <index_t N>
struct non_native_vector_base<bf8_ocp_t, N>
{
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret = CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
}
using data_t = bf8_ocp_t::data_type;
using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
using type = non_native_vector_base<bf8_ocp_t, N>;
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
data_v d; // storage vector
#else
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(data_t a) : d{a} {}
__host__ __device__ non_native_vector_base(data_v v) : d{v} {}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
__host__ __device__ operator data_v() const { return d; }
};
namespace fp8_impl {
......
......@@ -1036,10 +1036,6 @@ struct non_native_vector_base
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
// {
// using type = T;
// static constexpr index_t vector_size = N;
// };
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
......@@ -1049,6 +1045,14 @@ struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
static constexpr index_t vector_size = N;
};
// scalar_type traits for a non-native vector of N bf8_ocp_t elements.
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
    // Element count, taken directly from the template argument.
    static constexpr index_t vector_size = N;

    // Scalar type, forwarded from the vector wrapper's own data_t alias.
    using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment