Unverified commit 97042d87, authored by Andriy Roshchenko and committed by GitHub.
Browse files

Implement `non_native_vector_base` with `ext_vector_type` array. (#232)

* Enable support of 1, 2, 4, and 8-byte custom types in CK.
parent 1084c64c
......@@ -293,9 +293,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
} // namespace fp8_impl
// Forward declaration only: introduces the non_native_vector_base template
// name with an element type T and an element count N. No definition is
// provided at this point.
template <typename T, index_t N>
struct non_native_vector_base;
struct f8_ocp_t
{
using data_type = fp8_storage_t;
......@@ -389,111 +386,6 @@ struct bf8_ocp_t
}
};
// Generic vector storage for f8_ocp_t elements. The payload is kept as raw
// data_t values inside an ext_vector_type vector with sizeof(data_t) * N lanes.
template <index_t N>
struct non_native_vector_base<f8_ocp_t, N>
{
    using type   = non_native_vector_base<f8_ocp_t, N>;
    using data_t = f8_ocp_t::data_type;

    // The wrapper type and its raw storage type must have the same size.
    static_assert(sizeof(f8_ocp_t) == sizeof(data_t),
                  "non_native_vector_base storage size mismatch");

    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;

    // Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v vec) : d{vec} {}

    // Brace-initialize the storage vector from one raw value.
    __host__ __device__ non_native_vector_base(data_t raw) : d{raw} {}

    // Forward the wrapper's `data` member to the raw-value constructor.
    __host__ __device__ non_native_vector_base(f8_ocp_t value)
        : non_native_vector_base(value.data)
    {
    }

    // Implicit conversion back to the storage vector.
    __host__ __device__ operator data_v() const { return d; }
};
// Single-element specialization for f8_ocp_t. In addition to the storage
// vector conversion it converts implicitly to the raw value and to f8_ocp_t,
// both read from lane 0.
template <>
struct non_native_vector_base<f8_ocp_t, 1>
{
    using type   = non_native_vector_base<f8_ocp_t, 1>;
    using data_t = f8_ocp_t::data_type;
    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t))));

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;

    // Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v vec) : d{vec} {}

    // Brace-initialize the storage vector from one raw value.
    __host__ __device__ non_native_vector_base(data_t raw) : d{raw} {}

    // Forward the wrapper's `data` member to the raw-value constructor.
    __host__ __device__ non_native_vector_base(f8_ocp_t value)
        : non_native_vector_base(value.data)
    {
    }

    // Whole storage vector.
    __host__ __device__ operator data_v() const { return d; }

    // Raw value held in lane 0.
    __host__ __device__ operator data_t() const { return d[0]; }

    // Lane 0 re-wrapped as f8_ocp_t.
    __host__ __device__ operator f8_ocp_t() const { return f8_ocp_t{d[0]}; }
};
// Two-element specialization for f8_ocp_t, stored as fp8_impl::fp8x2_storage_t.
// Adds an explicit conversion of both lanes to float2_t.
template <>
struct non_native_vector_base<f8_ocp_t, 2>
{
    using type     = non_native_vector_base<f8_ocp_t, 2>;
    using data_t   = f8_ocp_t::data_type;
    using data_v   = fp8_impl::fp8x2_storage_t; // type of storage vector
    using float2_t = fp8_impl::float2_t;

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;

    // Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v vec) : d{vec} {}

    // Brace-initialize the storage vector from one raw value.
    __host__ __device__ non_native_vector_base(data_t raw) : d{raw} {}

    // Forward the wrapper's `data` member to the raw-value constructor.
    __host__ __device__ non_native_vector_base(f8_ocp_t value)
        : non_native_vector_base(value.data)
    {
    }

    // Implicit conversion back to the storage vector.
    __host__ __device__ operator data_v() const { return d; }

    // Explicit conversion to a pair of floats. The operator is always
    // host-callable; it is additionally device-callable only when
    // CK_USE_OCP_FP8 is set.
    __host__
#if CK_USE_OCP_FP8
    __device__
#endif
    explicit operator float2_t() const
    {
#if CK_OCP_FP8_CVT_FAST_PATH
        // Fast path: convert the packed pair in one call.
        return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d);
#else
        // Fallback: convert lane 0 and lane 1 separately, in that order.
        const auto lane0 =
            fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]);
        const auto lane1 =
            fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[1]);
        return float2_t{lane0, lane1};
#endif
    }
};
template <index_t N>
struct non_native_vector_base<bf8_ocp_t, N>
{
using data_t = bf8_ocp_t::data_type;
using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
using type = non_native_vector_base<bf8_ocp_t, N>;
data_v d; // storage vector
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(data_t a) : d{a} {}
__host__ __device__ non_native_vector_base(data_v v) : d{v} {}
__host__ __device__ operator data_v() const { return d; }
};
// Single-element specialization for bf8_ocp_t. In addition to the storage
// vector conversion it converts implicitly to the raw value and to bf8_ocp_t,
// both read from lane 0.
template <>
struct non_native_vector_base<bf8_ocp_t, 1>
{
    using type   = non_native_vector_base<bf8_ocp_t, 1>;
    using data_t = bf8_ocp_t::data_type;
    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t))));

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;

    // Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v vec) : d{vec} {}

    // Brace-initialize the storage vector from one raw value.
    __host__ __device__ non_native_vector_base(data_t raw) : d{raw} {}

    // Forward the wrapper's `data` member to the raw-value constructor.
    __host__ __device__ non_native_vector_base(bf8_ocp_t value)
        : non_native_vector_base(value.data)
    {
    }

    // Whole storage vector.
    __host__ __device__ operator data_v() const { return d; }

    // Raw value held in lane 0.
    __host__ __device__ operator data_t() const { return d[0]; }

    // Lane 0 re-wrapped as bf8_ocp_t.
    __host__ __device__ operator bf8_ocp_t() const { return bf8_ocp_t{d[0]}; }
};
// Declaration only: fp8_is_nan is a constexpr function template that takes a
// single value of type T and returns bool. Its definition is not part of this
// excerpt.
template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);
......
......@@ -1024,17 +1024,124 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
};
// Primary template, declared but not defined here. The trailing Enable
// parameter defaults to void so that specializations can be constrained
// through it (e.g. with std::enable_if_t).
template <typename T, index_t N, typename Enable = void>
struct non_native_vector_base;
// Maps an element type T to the raw storage type used for it. By default this
// is an unsigned bit-precise integer that is 8 * sizeof(T) bits wide.
template <typename T>
struct nnvb_data_t_selector
{
    using type = unsigned _BitInt(sizeof(T) * 8);
};
template <>
struct nnvb_data_t_selector<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
};
// Vector storage for a custom (non-native) element type T with N elements.
// This partial specialization is enabled only when sizeof(T) is 1, 2, 4 or 8.
// The same bytes can be viewed as a data_v vector, as N raw data_t values,
// as N values of T, or as a one-element array holding the data_v vector.
template <typename T, index_t N>
// NOTE(review): the next line looks like the pre-change struct header left
// over from the diff view; it is immediately followed by the new header --
// confirm and drop one of them.
struct non_native_vector_base
struct non_native_vector_base<
T,
N,
std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{
// NOTE(review): `type` is aliased again a few lines below; one of the two
// aliases is presumably a leftover from the diff view -- confirm.
using type = non_native_vector_base<T, N>;
using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
// The raw storage type must have exactly the size of T.
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
// N-lane ext_vector_type vector of raw values.
using data_v = data_t __attribute__((ext_vector_type(N)));
using type = non_native_vector_base<T, N>;
// Overlapping views of the same storage, aligned to next_pow2(N * sizeof(T)).
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
// Builds the storage from a data_v that is brace-initialized with one raw value.
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v{a}} {}
// Reinterprets a T as its raw data_t via bit_cast and delegates to the
// raw-value constructor.
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
// Default construction delegates to the T constructor with a value-initialized T.
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
// Adopts an already-built storage vector.
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
// NOTE(review): the four defaulted special members below look like pre-change
// lines from the diff view; the defaulted default constructor would redeclare
// the user-provided one above -- confirm which set is intended.
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
// Implicit conversion to the storage vector.
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
// Implicit conversion to a single raw value; only meaningful for N == 1.
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
// For N != 1 this returns the whole array where a data_t is expected; the
// original comment states the intent is a compile error.
return data_.dxN; // XXX this should cause an error
}
}
// Implicit conversion to a single T; only meaningful for N == 1.
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
// Same deliberate mismatch as in the data_t conversion above.
return data_.dTxN; // XXX this should cause an error
}
}
// NOTE(review): this array member looks like the pre-change storage left over
// from the diff view; the union `data_` above appears to replace it -- confirm.
T d[N];
// Read-only view of the storage selected by X. X must be data_t, T or data_v
// (enforced by the static_assert); the result is a reference to the matching
// union member.
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
// Not selected for any X that passes the static_assert above. `err` is not
// declared in the visible code.
return err;
}
}
// Mutable counterpart of the const AsType above, with the same X restrictions.
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
// Not selected for any X that passes the static_assert above. `err` is not
// declared in the visible code.
return err;
}
}
};
template <typename T, index_t N>
......@@ -1073,7 +1180,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
__host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{{v}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
......
......@@ -451,6 +451,20 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnu
#endif
}
// Converts a pair of OCP fp8 values to a pair of floats.
// With CK_OCP_FP8_CVT_FAST_PATH the packed pair is converted in one call;
// otherwise element 0 and element 1 are converted one at a time, in that order.
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
{
#if CK_OCP_FP8_CVT_FAST_PATH
    // View the input as a single packed fp8x2 storage value.
    auto&& packed = x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}];
    return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(packed);
#else
    // View the input as two raw fp8 storage values.
    auto&& raw0 = x.AsType<fp8_storage_t>()[Number<0>{}];
    auto&& raw1 = x.AsType<fp8_storage_t>()[Number<1>{}];

    const auto first  = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(raw0);
    const auto second = fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(raw1);

    return float2_t{first, second};
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment