Commit 1a96b749 authored by Andriy Roshchenko
Browse files

Improve infrastructure for OFP8 data type support.

parent 974d67f2
...@@ -845,9 +845,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, ...@@ -845,9 +845,9 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else #else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(0); return src_thread_element_valid ? tmp : vector_t{0};
#endif #endif
} }
...@@ -875,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, ...@@ -875,8 +875,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size; constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>( vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0); src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(customized_value); return src_thread_element_valid ? tmp : vector_t(customized_value);
} }
......
...@@ -396,11 +396,31 @@ struct non_native_vector_base<f8_ocp_t, N> ...@@ -396,11 +396,31 @@ struct non_native_vector_base<f8_ocp_t, N>
__host__ __device__ non_native_vector_base() = default; __host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(data_t a) : d{a} {} __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
__host__ __device__ non_native_vector_base(f8_ocp_t f) : non_native_vector_base(f.data) {}
__host__ __device__ non_native_vector_base(data_v v) : d{v} {} __host__ __device__ non_native_vector_base(data_v v) : d{v} {}
__host__ __device__ operator data_v() const { return d; } __host__ __device__ operator data_v() const { return d; }
}; };
/// One-element specialization of the non-native vector wrapper for f8_ocp_t.
///
/// Stores a single raw f8_ocp_t::data_type value inside a clang
/// ext_vector_type and converts implicitly to and from the raw storage value,
/// the storage vector, and f8_ocp_t itself.
template <>
struct non_native_vector_base<f8_ocp_t, 1>
{
    using data_t = f8_ocp_t::data_type;
    // Lane count is sizeof(data_t).
    // NOTE(review): this gives a single lane only if data_t is one byte wide - confirm.
    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t))));
    using type   = non_native_vector_base;

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;

    /// Wrap a raw storage value.
    __host__ __device__ non_native_vector_base(data_t raw) : d{raw} {}

    /// Wrap an f8_ocp_t by taking its raw `data` member.
    __host__ __device__ non_native_vector_base(f8_ocp_t value) : d{value.data} {}

    /// Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v vec) : d{vec} {}

    /// Expose the whole storage vector.
    __host__ __device__ operator data_v() const
    {
        return d;
    }

    /// Expose element 0 as the raw storage value.
    __host__ __device__ operator data_t() const
    {
        return d[0];
    }

    /// Rebuild an f8_ocp_t from element 0.
    __host__ __device__ operator f8_ocp_t() const
    {
        return f8_ocp_t{d[0]};
    }
};
template <> template <>
struct non_native_vector_base<f8_ocp_t, 2> struct non_native_vector_base<f8_ocp_t, 2>
{ {
...@@ -411,6 +431,10 @@ struct non_native_vector_base<f8_ocp_t, 2> ...@@ -411,6 +431,10 @@ struct non_native_vector_base<f8_ocp_t, 2>
__host__ __device__ non_native_vector_base() = default; __host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(data_t a) : d{a} {} __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
__host__ __device__ non_native_vector_base(f8_ocp_t f) : non_native_vector_base(f.data) {}
__host__ __device__ non_native_vector_base(data_v v) : d{v} {}
__host__ __device__ operator data_v() const { return d; }
using float2_t = fp8_impl::float2_t; using float2_t = fp8_impl::float2_t;
...@@ -445,6 +469,24 @@ struct non_native_vector_base<bf8_ocp_t, N> ...@@ -445,6 +469,24 @@ struct non_native_vector_base<bf8_ocp_t, N>
__host__ __device__ operator data_v() const { return d; } __host__ __device__ operator data_v() const { return d; }
}; };
/// One-element specialization of the non-native vector wrapper for bf8_ocp_t.
///
/// Stores a single raw bf8_ocp_t::data_type value inside a clang
/// ext_vector_type and converts implicitly to and from the raw storage value,
/// the storage vector, and bf8_ocp_t itself, matching the conversions offered
/// by the f8_ocp_t single-element specialization.
template <>
struct non_native_vector_base<bf8_ocp_t, 1>
{
    using data_t = bf8_ocp_t::data_type;
    // Lane count is sizeof(data_t).
    // NOTE(review): this gives a single lane only if data_t is one byte wide - confirm.
    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t))));
    using type   = non_native_vector_base<bf8_ocp_t, 1>;

    data_v d; // storage vector

    __host__ __device__ non_native_vector_base() = default;
    /// Wrap a raw storage value.
    __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
    /// Wrap a bf8_ocp_t by taking its raw `data` member.
    __host__ __device__ non_native_vector_base(bf8_ocp_t f) : non_native_vector_base(f.data) {}
    /// Adopt an already-built storage vector.
    __host__ __device__ non_native_vector_base(data_v v) : d{v} {}

    /// Expose the whole storage vector.
    __host__ __device__ operator data_v() const { return d; }
    /// Expose element 0 as the raw storage value.
    __host__ __device__ operator data_t() const { return d[0]; }
    /// Rebuild a bf8_ocp_t from element 0. Added so this specialization offers
    /// the same element-type conversion as the f8_ocp_t one.
    /// NOTE(review): assumes bf8_ocp_t is brace-constructible from data_t, as
    /// f8_ocp_t is - confirm against the bf8_ocp_t definition.
    __host__ __device__ operator bf8_ocp_t() const { return bf8_ocp_t{d[0]}; }
};
namespace fp8_impl { namespace fp8_impl {
template <typename T, template <typename T,
std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> || std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> ||
......
...@@ -1057,43 +1057,61 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>> ...@@ -1057,43 +1057,61 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using d1_nnv_t = non_native_vector_base<T, 1>;
using type = d1_nnv_t;
union alignas(next_pow2(1 * sizeof(T))) union alignas(next_pow2(1 * sizeof(T)))
{ {
d1_t d1_; d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_; StaticallyIndexedArray<d1_t, 1> d1x1_;
d1_nnv_t d1_nnv_;
StaticallyIndexedArray<d1_nnv_t, 1> d1nnvx1_;
} data_; } data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {} __host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {} __host__ __device__ constexpr vector_type(type v) : data_{{v}} {}
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
return data_.d1x1_; if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
} }
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
return data_.d1x1_; if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
} }
}; };
template <typename T> template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d2_t = non_native_vector_base<T, 2>; using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t; using type = d2_t;
...@@ -1101,6 +1119,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1101,6 +1119,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
{ {
d2_t d2_; d2_t d2_;
StaticallyIndexedArray<d1_t, 2> d1x2_; StaticallyIndexedArray<d1_t, 2> d1x2_;
StaticallyIndexedArray<d1_nnv_t, 2> d1nnvx2_;
StaticallyIndexedArray<d2_t, 1> d2x1_; StaticallyIndexedArray<d2_t, 1> d2x1_;
} data_; } data_;
...@@ -1111,10 +1130,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1111,10 +1130,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{ {
return data_.d1x2_; return data_.d1x2_;
} }
...@@ -1131,10 +1151,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1131,10 +1151,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{ {
return data_.d1x2_; return data_.d1x2_;
} }
...@@ -1222,10 +1243,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1222,10 +1243,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename T> template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d2_t = non_native_vector_base<T, 2>; using d1_nnv_t = non_native_vector_base<T, 1>;
using d4_t = non_native_vector_base<T, 4>; using d2_t = non_native_vector_base<T, 2>;
using d8_t = non_native_vector_base<T, 8>; using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t; using type = d8_t;
...@@ -1233,6 +1255,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1233,6 +1255,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
{ {
d8_t d8_; d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_; StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d1_nnv_t, 8> d1nnvx8_;
StaticallyIndexedArray<d2_t, 4> d2x4_; StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_; StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_; StaticallyIndexedArray<d8_t, 1> d8x1_;
...@@ -1245,11 +1268,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1245,11 +1268,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X> template <typename X>
__host__ __device__ constexpr const auto& AsType() const __host__ __device__ constexpr const auto& AsType() const
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{ {
return data_.d1x8_; return data_.d1x8_;
} }
...@@ -1274,11 +1298,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1274,11 +1298,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X> template <typename X>
__host__ __device__ constexpr auto& AsType() __host__ __device__ constexpr auto& AsType()
{ {
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value, is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types."); "Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value) if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{ {
return data_.d1x8_; return data_.d1x8_;
} }
......
...@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, ...@@ -150,7 +150,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1}); d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1<D0DataType>{1});
} }
......
...@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, ...@@ -157,7 +157,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
...@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, ...@@ -174,7 +174,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
break; break;
default: default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
...@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, ...@@ -140,7 +140,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
break; break;
default: default:
a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{}); b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment