Commit 739d3db9 authored by Andriy Roshchenko

Enable OCP build of example_gemm_xdl_fp8.

parent f1fe1ce6
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
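Note: the switch from static_cast to ck::type_convert is presumably what lets this example build when ADataType/BDataType are the OCP FP8 class types: a class type has no implicit conversion from float, so static_cast<ADataType>(1.f) stops compiling, while ck::type_convert dispatches to the dedicated FP8 conversion routines added later in this commit. A minimal stand-alone sketch of that dispatch shape (tiny_fp8 and this type_convert are hypothetical stand-ins, not CK code):

struct tiny_fp8              // stand-in for a class-type element such as f8_ocp_t
{
    unsigned char data;      // raw storage byte, like fp8_storage_t
};

template <typename Y, typename X>
Y type_convert(X x);         // generic conversion entry point (declaration only)

template <>
tiny_fp8 type_convert<tiny_fp8, float>(float x)
{
    // a real implementation would round to nearest-even into the FP8 bit pattern;
    // only the dispatch shape matters for this sketch
    return tiny_fp8{static_cast<unsigned char>(x)};
}

int main()
{
    // tiny_fp8 one = static_cast<tiny_fp8>(1.f); // would not compile: no converting ctor
    tiny_fp8 one = type_convert<tiny_fp8>(1.f);   // works for class-type elements
    return one.data == 1 ? 0 : 1;
}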
@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, fp8_storage_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
@@ -28,6 +28,12 @@ using bf8_fnuz_t = unsigned _BitInt(8);
 #define CK_FP8_CVT_FAST_PATH 0
 #endif
 
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
+#define CK_OFP8_CVT_FAST_PATH 1
+#else
+#define CK_OFP8_CVT_FAST_PATH 0
+#endif
+
 typedef unsigned char fp8_storage_t;
 
 /**
@@ -52,6 +58,9 @@ enum ck_saturation_t
 namespace fp8_impl {
 
+typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
+typedef float float2_t __attribute__((ext_vector_type(2)));
+
 __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
 {
     return static_cast<unsigned char>(a) == 0x80;
@@ -250,6 +259,33 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
         return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
     }
 }
 
+template <ck_fp8_interpretation_t interpret>
+static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
+{
+    // union
+    // {
+    //     unsigned int i32val;
+    //     unsigned short i16val[2];
+    // } val;
+    // val.i16val[0] = v;
+    const auto i16val = bit_cast<uint16_t>(v);
+
+    static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP ||
+                      interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP,
+                  "Only FNUZ and OCP interpretations are supported");
+
+    if constexpr((interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP))
+    {
+        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
+    }
+    else
+    {
+        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
+    }
+}
+
 #endif
 
 } // namespace fp8_impl
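Note: the new packed routine widens two FP8 values per builtin call on the fast-path targets. The step it adds over the scalar path is reinterpreting the two adjacent FP8 bytes as one 16-bit lane before handing them to the packed convert. A host-side sketch of just that packing step (names below are hypothetical; the device builtin itself cannot run on the host):

#include <cstdint>
#include <cstring>
#include <cstdio>

using fp8_byte = unsigned char;

std::uint16_t pack_fp8x2(fp8_byte lo, fp8_byte hi)
{
    fp8_byte pair[2] = {lo, hi};
    std::uint16_t packed;
    std::memcpy(&packed, pair, sizeof(packed)); // same effect as bit_cast<uint16_t>
    return packed;
}

int main()
{
    // 0x38 encodes 1.0 and 0xB8 encodes -1.0 in OCP E4M3
    std::uint16_t packed = pack_fp8x2(0x38, 0xB8);
    std::printf("packed lane = 0x%04x\n", packed); // 0xB838 on little-endian hosts
    return 0;
}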
@@ -276,7 +312,7 @@ struct f8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -290,7 +326,7 @@ struct f8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
@@ -299,6 +335,53 @@ struct f8_ocp_t
     }
 };
 
+template <typename T, index_t N>
+struct non_native_vector_base;
+
+template <index_t N>
+struct non_native_vector_base<f8_ocp_t, N>
+{
+    using data_t = f8_ocp_t::data_type;
+    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
+    using type   = non_native_vector_base<f8_ocp_t, N>;
+
+    data_v d; // storage vector
+
+    __host__ __device__ non_native_vector_base() = default;
+    __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
+    __host__ __device__ non_native_vector_base(data_v v) : d{v} {}
+
+    __host__ __device__ operator data_v() const { return d; }
+};
+
+template <>
+struct non_native_vector_base<f8_ocp_t, 2>
+{
+    using data_t = f8_ocp_t::data_type;
+    using type   = non_native_vector_base<f8_ocp_t, 2>;
+
+    __host__ __device__ non_native_vector_base() = default;
+
+    using data_v = fp8_impl::fp8x2_storage_t; // type of storage vector
+    data_v d;                                 // storage vector
+
+    using float2_t = fp8_impl::float2_t;
+
+#if CK_USE_OCP_FP8
+    __host__ __device__ explicit operator float2_t() const
+#else
+    __host__ explicit operator float2_t() const
+#endif
+    {
+#if CK_OFP8_CVT_FAST_PATH
+        return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d);
+#else
+        return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]),
+                        fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[1])};
+#endif
+    }
+};
+
 struct bf8_ocp_t
 {
     using data_type = fp8_storage_t;
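Note: the 2-wide specialization stores the pair of OCP FP8 elements in a clang ext_vector of raw bytes and exposes an explicit widening conversion to float2. A reduced, host-compilable sketch of that layout and conversion pattern (clang/hipcc only because of ext_vector_type; fp8x2_sketch is a hypothetical stand-in, and the widen below just casts the raw bytes rather than decoding FP8):

typedef unsigned char byte2_t __attribute__((ext_vector_type(2)));
typedef float float2_v __attribute__((ext_vector_type(2)));

#include <cstdio>

struct fp8x2_sketch
{
    byte2_t d; // two raw FP8 bytes, like non_native_vector_base<f8_ocp_t, 2>::d

    explicit operator float2_v() const
    {
        // placeholder element-wise widen; the real class calls cast_from_f8 or
        // the packed cvt builtin depending on CK_OFP8_CVT_FAST_PATH
        return float2_v{static_cast<float>(d[0]), static_cast<float>(d[1])};
    }
};

int main()
{
    fp8x2_sketch v{{0x38, 0xB8}};          // raw bytes, not decoded FP8 values here
    float2_v f = static_cast<float2_v>(v); // explicit conversion, as in the commit
    std::printf("%g %g\n", f[0], f[1]);    // prints the raw byte values 56 184
    return 0;
}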
@@ -1031,8 +1031,22 @@ struct non_native_vector_base
     __host__ __device__ non_native_vector_base() = default;
 
-    typedef char data_v __attribute__((ext_vector_type(sizeof(T) * N)));
-    data_v d;
+    T d[N];
 };
 
+template <typename T, index_t N>
+struct scalar_type<non_native_vector_base<T, N>>;
+// {
+//     using type = T;
+//     static constexpr index_t vector_size = N;
+// };
+
+template <index_t N>
+struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
+{
+    using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
+    static constexpr index_t vector_size = N;
+};
+
 // non-native vector_type implementation
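Note: scalar_type<> is the trait the vector machinery queries to learn an element type and lane count, and the new specialization lets non_native_vector_base<f8_ocp_t, N> answer that query. A generic sketch of how such a trait is typically consumed (fake_fp8x4 and lanes() are illustrative names, not CK call sites):

#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

template <typename V>
struct scalar_type; // primary template left undefined, as in CK

struct fake_fp8x4 { unsigned char d[4]; }; // stand-in for non_native_vector_base<f8_ocp_t, 4>

template <>
struct scalar_type<fake_fp8x4>
{
    using type = unsigned char;
    static constexpr index_t vector_size = 4;
};

template <typename V>
constexpr index_t lanes() { return scalar_type<V>::vector_size; }

static_assert(lanes<fake_fp8x4>() == 4, "trait reports the lane count");
static_assert(std::is_same_v<typename scalar_type<fake_fp8x4>::type, unsigned char>,
              "trait reports the element type");

int main() { return 0; }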
@@ -404,6 +404,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
 #endif
 }
 
+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
+
 // convert fp8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
@@ -461,6 +472,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
 #endif
 }
 
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
+
 // convert fp8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
@@ -485,6 +507,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
 #endif
 }
 
+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
+
 // convert bf8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
@@ -512,6 +545,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
 #endif
 }
 
+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
+
 // convert bf8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
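Note: with these specializations in place, the generic ck::type_convert entry point covers the OCP types as well, defaulting to round-to-nearest-even and switching to stochastic rounding when CK_USE_SR_F8_CONVERSION is defined. A self-contained sketch of that compile-time rounding-mode switch (fake_f8, convert_rne and convert_sr are hypothetical placeholders, not the CK routines):

#include <cstdint>
#include <cstdio>
#include <random>

struct fake_f8 { std::uint8_t data; }; // stand-in for f8_ocp_t

fake_f8 convert_rne(float x)
{
    // placeholder: a real implementation rounds the mantissa to nearest-even
    return fake_f8{static_cast<std::uint8_t>(x + 0.5f)};
}

fake_f8 convert_sr(float x)
{
    // placeholder: a real implementation adds uniform noise before truncating
    static std::mt19937 rng{42};
    std::uniform_real_distribution<float> noise(0.f, 1.f);
    return fake_f8{static_cast<std::uint8_t>(x + noise(rng))};
}

fake_f8 type_convert_f8(float x)
{
#if defined(CK_USE_SR_F8_CONVERSION)
    return convert_sr(x);  // stochastic-rounding build
#else
    return convert_rne(x); // default: round-to-nearest-even
#endif
}

int main()
{
    std::printf("%u\n", static_cast<unsigned>(type_convert_f8(3.4f).data));
    return 0;
}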
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
         auto f_mk_kn_mn = [&](auto m, auto n) {
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];
 
-            AccDataType v_acc = 0;
-            ComputeTypeA v_a  = 0;
-            ComputeTypeB v_b  = 0;
+            AccDataType v_acc{0};
+            ComputeTypeA v_a{0};
+            ComputeTypeB v_b{0};
 
             for(int k = 0; k < K; ++k)
             {
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                     ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
-            CDataType v_c = 0;
+            CDataType v_c{0};
 
             arg.c_element_op_(v_c, v_acc);
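Note: the move from "= 0" to brace initialization is what keeps the reference GEMM generic once element types can be class types such as f8_ocp_t. Copy-initialization from an int needs an implicit converting constructor, while T v{0}; aggregate- or value-initializes and therefore compiles for both arithmetic and wrapper types. A minimal, self-contained illustration (wrapped_fp8 is a hypothetical aggregate standing in for f8_ocp_t):

struct wrapped_fp8
{
    unsigned char data; // raw storage, like fp8_storage_t
};

template <typename T>
T zero_braced()
{
    T v{0}; // valid for float and for the aggregate wrapper alike
    return v;
}

int main()
{
    float f       = zero_braced<float>();
    wrapped_fp8 q = zero_braced<wrapped_fp8>();
    // wrapped_fp8 bad = 0; // would not compile: no implicit conversion from int
    return static_cast<int>(f) + q.data; // both are zero
}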
@@ -25,17 +25,17 @@ template <typename ALayout,
           typename ComputeTypeB>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
                           const BDataType* __restrict__ p_b_grid,
                           CDataType* __restrict__ p_c_grid,
                           index_t m,
                           index_t n,
                           index_t k,
                           const AElementwiseOperation a_element_op,
                           const BElementwiseOperation b_element_op,
                           const CDEElementwiseOperation c_element_op)
 {
     using RowMajor = ck::tensor_layout::gemm::RowMajor;
@@ -45,10 +45,10 @@ __global__ void
     if(row_idx < m && col_idx < n)
     {
-        AccDataType v_acc = static_cast<AccDataType>(0.0);
-        ComputeTypeA v_a  = static_cast<ComputeTypeA>(0.0);
-        ComputeTypeB v_b  = static_cast<ComputeTypeB>(0.0);
-        CDataType v_c     = static_cast<CDataType>(0.0);
+        AccDataType v_acc{0};
+        ComputeTypeA v_a{0};
+        ComputeTypeB v_b{0};
+        CDataType v_c{0};
 
         for(int k_idx = 0; k_idx < k; ++k_idx)
         {
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::half_t operator()(Is...)
     {
         return ck::type_convert<ck::half_t>(value);
     }
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::f8_t operator()(Is...)
    {
        return ck::type_convert<ck::f8_t>(value);
    }
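Note: both call operators previously declared a bhalf_t return type while converting to half_t or f8_t, most likely a copy-paste slip from the bhalf_t specialization. In an OCP build, where f8_t becomes a class type, the converted value no longer converts implicitly back to bhalf_t, so the mismatched declaration presumably stops compiling; matching the return type to the converted type fixes that. A reduced illustration with hypothetical stand-ins:

using bhalf_like = unsigned short;          // plays the role of ck::bhalf_t
struct f8_like { unsigned char data; };     // class type, like f8_ocp_t

template <typename T> T convert(float x);   // generic convert, sketch only
template <> f8_like convert<f8_like>(float x) { return f8_like{static_cast<unsigned char>(x)}; }

struct generator_fixed
{
    float value = 1.0f;
    f8_like operator()() { return convert<f8_like>(value); } // return type matches the convert
};

// struct generator_broken
// {
//     float value = 1.0f;
//     bhalf_like operator()() { return convert<f8_like>(value); } // error: no conversion to integer
// };

int main() { return generator_fixed{}().data == 1 ? 0 : 1; }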