Commit fe0ced87 authored by root's avatar root
Browse files

merge

parents dd3a5424 14dc7552
...@@ -29,7 +29,7 @@ struct PassThrough ...@@ -29,7 +29,7 @@ struct PassThrough
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const
{ {
auto t = type_convert<float2_t>(x); auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t); y = type_convert<half2_t>(t);
} }
......
...@@ -148,10 +148,14 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x) ...@@ -148,10 +148,14 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0); return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x); const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v; vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<0>{}]); f32x2_v.template AsType<float>()(Number<0>{}) =
f32x2_v.template AsType<float>()(Number<1>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<1>{}]); utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}]; return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif #endif
} }
...@@ -160,13 +164,12 @@ template <> ...@@ -160,13 +164,12 @@ template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{ {
const vector_type<float, 2> f32x2_v(x); const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}], f32x2_v.template AsType<float>()[Number<1>{}]); const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
return bit_cast<half2_t>(y); f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
} }
// convert fp16 to fp8 // convert fp16 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
......
...@@ -28,7 +28,8 @@ using S = ck::Sequence<Is...>; ...@@ -28,7 +28,8 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; // static constexpr auto GemmMNPadding =
// ck::tensor_operation::device::GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple< using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment