Commit 739d3db9 authored by Andriy Roshchenko

Enable OCP build of example_gemm_xdl_fp8.

parent f1fe1ce6
@@ -143,8 +143,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     switch(config.init_method)
     {
     case 0:
-        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
-        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.f)}(b_k_n);
         break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
......
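Note on the change above: `ck::type_convert` is CK's conversion customization point, and routing the fill value through it is what lets the same example code build when `ADataType`/`BDataType` are byte-backed OCP FP8 wrappers, which are not necessarily constructible with a plain `static_cast` from `float`. A minimal, self-contained sketch of the pattern with hypothetical names (`convert`, `toy_fp8`, `FillConstantLike` are illustrations, not CK's actual API):

```cpp
// Sketch only: mirrors the fill-through-a-convert-function pattern used above.
#include <cstdint>
#include <vector>

// Fallback conversion: plain static_cast is fine for built-in arithmetic types.
template <typename T>
T convert(float x) { return static_cast<T>(x); }

// Stand-in for a byte-backed FP8 wrapper (the real ck::f8_ocp_t stores one
// fp8_storage_t byte and provides its own conversion routines).
struct toy_fp8
{
    std::uint8_t bits;
};

// Specialized conversion for the wrapper; 0x38 happens to encode 1.0 in OCP E4M3.
template <>
toy_fp8 convert<toy_fp8>(float x)
{
    return toy_fp8{x == 1.0f ? std::uint8_t{0x38} : std::uint8_t{0}};
}

// Simplified analogue of ck::utils::FillConstant.
template <typename T>
struct FillConstantLike
{
    T value;
    void operator()(std::vector<T>& dst) const
    {
        for(auto& e : dst)
            e = value;
    }
};

int main()
{
    std::vector<toy_fp8> a(8);
    FillConstantLike<toy_fp8>{convert<toy_fp8>(1.f)}(a); // same shape as the diff above
}
```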
@@ -549,8 +549,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
         (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+        (is_same<T, fp8_storage_t>::value &&
+         (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
......
@@ -28,6 +28,12 @@ using bf8_fnuz_t = unsigned _BitInt(8);
 #define CK_FP8_CVT_FAST_PATH 0
 #endif
+
+#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
+#define CK_OFP8_CVT_FAST_PATH 1
+#else
+#define CK_OFP8_CVT_FAST_PATH 0
+#endif
 typedef unsigned char fp8_storage_t;
 /**
@@ -52,6 +58,9 @@ enum ck_saturation_t
 namespace fp8_impl {
+
+typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
+typedef float float2_t __attribute__((ext_vector_type(2)));
 
 __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
 {
     return static_cast<unsigned char>(a) == 0x80;
@@ -250,6 +259,33 @@ static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
         return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
     }
 }
+
+template <ck_fp8_interpretation_t interpret>
+static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
+{
+    // union
+    // {
+    //     unsigned int i32val;
+    //     unsigned short i16val[2];
+    // } val;
+    // val.i16val[0] = v;
+    const auto i16val = bit_cast<uint16_t>(v);
+
+    static_assert(interpret == CK_E4M3_FNUZ || interpret == CK_E4M3_OCP ||
+                      interpret == CK_E5M2_FNUZ || interpret == CK_E5M2_OCP,
+                  "Only FNUZ and OCP interpretations are supported");
+
+    if constexpr((interpret == CK_E4M3_FNUZ) || (interpret == CK_E4M3_OCP))
+    {
+        return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
+    }
+    else
+    {
+        return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
+    }
+}
 #endif
 } // namespace fp8_impl
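For context on the new packed path: `__builtin_amdgcn_cvt_pk_f32_fp8` / `__builtin_amdgcn_cvt_pk_f32_bf8` convert two packed FP8 values to two floats on the targets gated by `CK_OFP8_CVT_FAST_PATH`. The host-side sketch below (my own illustration, not CK code) shows the numerics expected for the OCP E4M3 interpretation: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits, no infinities, and `S.1111.111` as the only NaN pattern.

```cpp
// Reference decode of two OCP E4M3 bytes to two floats (illustration only).
#include <array>
#include <cmath>
#include <cstdint>
#include <cstdio>

float e4m3_to_float(std::uint8_t b)
{
    const float sign = (b & 0x80) ? -1.0f : 1.0f;
    const int exp    = (b >> 3) & 0xF;
    const int mant   = b & 0x7;
    if((b & 0x7F) == 0x7F)
        return std::nanf(""); // S.1111.111 is the only NaN encoding; there is no Inf
    if(exp == 0)
        return sign * std::ldexp(static_cast<float>(mant), -9); // subnormal: (mant/8) * 2^-6
    return sign * std::ldexp(1.0f + mant / 8.0f, exp - 7);      // normal
}

std::array<float, 2> e4m3x2_to_float2(std::uint16_t packed)
{
    // Low byte maps to element 0, high byte to element 1.
    return {e4m3_to_float(static_cast<std::uint8_t>(packed & 0xFF)),
            e4m3_to_float(static_cast<std::uint8_t>(packed >> 8))};
}

int main()
{
    // 0x38 encodes 1.0; 0x7E encodes 448, the largest finite E4M3 value.
    const auto f2 = e4m3x2_to_float2(0x7E38);
    std::printf("%g %g\n", f2[0], f2[1]); // prints: 1 448
}
```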
@@ -276,7 +312,7 @@ struct f8_ocp_t
     __host__ explicit operator float() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 #else
         return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -290,7 +326,7 @@ struct f8_ocp_t
     __host__ explicit operator _Float16() const
 #endif
     {
-#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
+#if CK_OFP8_CVT_FAST_PATH
         return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 #else
         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
@@ -299,6 +335,53 @@ struct f8_ocp_t
     }
 };
+
+template <typename T, index_t N>
+struct non_native_vector_base;
+
+template <index_t N>
+struct non_native_vector_base<f8_ocp_t, N>
+{
+    using data_t = f8_ocp_t::data_type;
+    using data_v = data_t __attribute__((ext_vector_type(sizeof(data_t) * N)));
+    using type   = non_native_vector_base<f8_ocp_t, N>;
+
+    data_v d; // storage vector
+
+    __host__ __device__ non_native_vector_base() = default;
+    __host__ __device__ non_native_vector_base(data_t a) : d{a} {}
+    __host__ __device__ non_native_vector_base(data_v v) : d{v} {}
+
+    __host__ __device__ operator data_v() const { return d; }
+};
+
+template <>
+struct non_native_vector_base<f8_ocp_t, 2>
+{
+    using data_t = f8_ocp_t::data_type;
+    using type   = non_native_vector_base<f8_ocp_t, 2>;
+
+    __host__ __device__ non_native_vector_base() = default;
+
+    using data_v = fp8_impl::fp8x2_storage_t; // type of storage vector
+    data_v d;                                 // storage vector
+    using float2_t = fp8_impl::float2_t;
+
+#if CK_USE_OCP_FP8
+    __host__ __device__ explicit operator float2_t() const
+#else
+    __host__ explicit operator float2_t() const
#endif
+    {
+#if CK_OFP8_CVT_FAST_PATH
+        return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(d);
+#else
+        return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[0]),
+                        fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(d[1])};
+#endif
+    }
+};
+
 struct bf8_ocp_t
 {
     using data_type = fp8_storage_t;
......
@@ -1031,8 +1031,22 @@ struct non_native_vector_base
     __host__ __device__ non_native_vector_base() = default;
-    typedef char data_v __attribute__((ext_vector_type(sizeof(T) * N)));
-    data_v d;
+    T d[N];
 };
+
+template <typename T, index_t N>
+struct scalar_type<non_native_vector_base<T, N>>;
+// {
+//     using type = T;
+//     static constexpr index_t vector_size = N;
+// };
+
+template <index_t N>
+struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
+{
+    using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
+
+    static constexpr index_t vector_size = N;
+};
+
 // non-native vector_type implementation
......
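The `scalar_type` specialization above is what lets generic vector utilities recover the element type and lane count of the new f8 vector wrapper at compile time. A simplified, hypothetical sketch of how such a trait is typically consumed (toy types, not CK's definitions):

```cpp
// Trait-usage sketch only; toy_vector and this scalar_type are illustrations.
#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

template <typename T, index_t N>
struct toy_vector // stand-in for non_native_vector_base<f8_ocp_t, N>
{
    T d[N];
};

template <typename V>
struct scalar_type; // primary template declared but not defined, as in the diff

template <typename T, index_t N>
struct scalar_type<toy_vector<T, N>>
{
    using type                           = T;
    static constexpr index_t vector_size = N;
};

int main()
{
    using V = toy_vector<std::uint8_t, 4>;
    static_assert(std::is_same_v<scalar_type<V>::type, std::uint8_t>, "element type");
    static_assert(scalar_type<V>::vector_size == 4, "lane count");
}
```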
@@ -404,6 +404,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
 #endif
 }
+
+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
 // convert fp8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
@@ -461,6 +472,17 @@ inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
 #endif
 }
+
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_ocp_t>(x);
+#else
+    return f8_convert_rne<f8_ocp_t>(x);
+#endif
+}
 // convert fp8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
@@ -485,6 +507,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
 #endif
 }
+
+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
 // convert bf8 to fp32
 template <>
 inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
@@ -512,6 +545,17 @@ inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
 #endif
 }
+
+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_ocp_t>(x);
+#else
+    return f8_convert_rne<bf8_ocp_t>(x);
+#endif
+}
 // convert bf8 to fp16
 template <>
 inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
......
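All four new specializations dispatch to `f8_convert_sr` (stochastic rounding) or `f8_convert_rne` (round-to-nearest-even) depending on `CK_USE_SR_F8_CONVERSION`. As a reference for what the RNE path is expected to produce for OCP E4M3, here is a deliberately simple brute-force sketch (my own, host-only, not CK's implementation): pick the nearest of the 256 encodings, break exact ties toward the encoding with an even low bit, and saturate overflow to the largest finite magnitude, 448.

```cpp
// Brute-force round-to-nearest-even reference for float -> OCP E4M3 (sketch only).
#include <cmath>
#include <cstdint>
#include <cstdio>

static float e4m3_decode(std::uint8_t b)
{
    const float sign = (b & 0x80) ? -1.0f : 1.0f;
    const int exp    = (b >> 3) & 0xF;
    const int mant   = b & 0x7;
    if((b & 0x7F) == 0x7F)
        return std::nanf(""); // only NaN pattern
    if(exp == 0)
        return sign * std::ldexp(static_cast<float>(mant), -9); // subnormal
    return sign * std::ldexp(1.0f + mant / 8.0f, exp - 7);      // normal
}

static std::uint8_t e4m3_encode_rne(float x)
{
    if(std::isnan(x))
        return 0x7F;
    if(std::isinf(x)) // saturate infinities to the largest finite magnitude
        return x > 0 ? std::uint8_t{0x7E} : std::uint8_t{0xFE};

    std::uint8_t best = 0;
    float best_err    = INFINITY;
    for(int code = 0; code < 256; ++code)
    {
        const std::uint8_t c = static_cast<std::uint8_t>(code);
        const float v        = e4m3_decode(c);
        if(std::isnan(v))
            continue; // skip the two NaN encodings
        const float err = std::fabs(x - v);
        // Smaller error wins; an exact tie goes to the even low bit (nearest-even).
        if(err < best_err || (err == best_err && (c & 1) == 0 && (best & 1) != 0))
        {
            best     = c;
            best_err = err;
        }
    }
    return best; // values beyond +/-448 end up clamped to the max finite encoding
}

int main()
{
    std::printf("%#04x %#04x %#04x\n",
                static_cast<unsigned>(e4m3_encode_rne(1.0f)),    // 0x38
                static_cast<unsigned>(e4m3_encode_rne(0.5f)),    // 0x30
                static_cast<unsigned>(e4m3_encode_rne(1000.f))); // 0x7e
}
```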
@@ -62,9 +62,9 @@ struct ReferenceGemm : public device::BaseOperator
         auto f_mk_kn_mn = [&](auto m, auto n) {
             const int K = arg.a_m_k_.mDesc.GetLengths()[1];
 
-            AccDataType v_acc = 0;
-            ComputeTypeA v_a = 0;
-            ComputeTypeB v_b = 0;
+            AccDataType v_acc{0};
+            ComputeTypeA v_a{0};
+            ComputeTypeB v_b{0};
 
             for(int k = 0; k < K; ++k)
             {
@@ -93,7 +93,7 @@ struct ReferenceGemm : public device::BaseOperator
                     ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
             }
 
-            CDataType v_c = 0;
+            CDataType v_c{0};
 
             arg.c_element_op_(v_c, v_acc);
......
@@ -25,17 +25,17 @@ template <typename ALayout,
           typename ComputeTypeB>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
-                  const BDataType* __restrict__ p_b_grid,
-                  CDataType* __restrict__ p_c_grid,
-                  index_t m,
-                  index_t n,
-                  index_t k,
-                  const AElementwiseOperation a_element_op,
-                  const BElementwiseOperation b_element_op,
-                  const CDEElementwiseOperation c_element_op)
+naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
+                  const BDataType* __restrict__ p_b_grid,
+                  CDataType* __restrict__ p_c_grid,
+                  index_t m,
+                  index_t n,
+                  index_t k,
+                  const AElementwiseOperation a_element_op,
+                  const BElementwiseOperation b_element_op,
+                  const CDEElementwiseOperation c_element_op)
 {
     using RowMajor = ck::tensor_layout::gemm::RowMajor;
@@ -45,10 +45,10 @@ __global__ void
     if(row_idx < m && col_idx < n)
     {
-        AccDataType v_acc = static_cast<AccDataType>(0.0);
-        ComputeTypeA v_a = static_cast<ComputeTypeA>(0.0);
-        ComputeTypeB v_b = static_cast<ComputeTypeB>(0.0);
-        CDataType v_c = static_cast<CDataType>(0.0);
+        AccDataType v_acc{0};
+        ComputeTypeA v_a{0};
+        ComputeTypeB v_b{0};
+        CDataType v_c{0};
 
         for(int k_idx = 0; k_idx < k; ++k_idx)
         {
......
@@ -37,7 +37,7 @@ struct GeneratorTensor_1<ck::half_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::half_t operator()(Is...)
     {
         return ck::type_convert<ck::half_t>(value);
     }
@@ -62,7 +62,7 @@ struct GeneratorTensor_1<ck::f8_t>
     float value = 1.0;
 
     template <typename... Is>
-    ck::bhalf_t operator()(Is...)
+    ck::f8_t operator()(Is...)
     {
         return ck::type_convert<ck::f8_t>(value);
     }
......