"vscode:/vscode.git/clone" did not exist on "6ff3fe5d059f98b518091144bde9df7b8456af4f"
Commit fe0ced87 authored by root

merge

parents dd3a5424 14dc7552
@@ -150,8 +150,12 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
     constexpr bool negative_zero_nan = true;
     const auto f8x2_v = vector_type<f8_t, 2>(x);
     vector_type<float, 2> f32x2_v;
-    f32x2_v.template AsType<float>()(Number<0>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<0>{}]);
-    f32x2_v.template AsType<float>()(Number<1>{}) = utils::cast_from_f8<f8_t, float, negative_zero_nan>(f8x2_v.template AsType<f8_t>()[Number<1>{}]);
+    f32x2_v.template AsType<float>()(Number<0>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<0>{}]);
+    f32x2_v.template AsType<float>()(Number<1>{}) =
+        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_t>()[Number<1>{}]);
     return f32x2_v.template AsType<float2_t>()[Number<0>{}];
 #endif
 }
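The hunk above only re-wraps the two cast_from_f8 calls, but for readers unfamiliar with the fp8 path, the sketch below shows roughly what such a decode does. It assumes an E4M3 layout with the FNUZ convention that negative_zero_nan = true suggests (exponent bias 8, no infinities, the 0x80 pattern treated as NaN); it is an illustrative host-side sketch, not CK's utils::cast_from_f8, whose format parameters may differ.

#include <cmath>
#include <cstdint>

// Illustrative sketch only: decode one 8-bit E4M3-style value to float under
// the assumed FNUZ convention (bias 8, no inf, 0x80 repurposed as NaN).
inline float decode_f8_e4m3fnuz(uint8_t x)
{
    if(x == 0x80) // negative_zero_nan: the would-be -0 pattern encodes NaN
        return std::nanf("");
    if(x == 0x00)
        return 0.0f;

    const float sign   = (x & 0x80) ? -1.0f : 1.0f;
    const int exponent = (x >> 3) & 0xF; // 4 exponent bits
    const int mantissa = x & 0x7;        // 3 mantissa bits
    constexpr int bias = 8;

    if(exponent == 0) // subnormal: sign * 2^(1 - bias) * (mantissa / 8)
        return sign * std::ldexp(static_cast<float>(mantissa) / 8.0f, 1 - bias);

    // normal: sign * 2^(exponent - bias) * (1 + mantissa / 8)
    return sign * std::ldexp(1.0f + static_cast<float>(mantissa) / 8.0f, exponent - bias);
}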
@@ -161,12 +165,11 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
 {
     const vector_type<float, 2> f32x2_v(x);
-    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}], f32x2_v.template AsType<float>()[Number<1>{}]);
+    const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
+                                              f32x2_v.template AsType<float>()[Number<1>{}]);
     return bit_cast<half2_t>(y);
 }

 // convert fp16 to fp8
 template <>
 inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
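For context, __builtin_amdgcn_cvt_pkrtz lowers to the v_cvt_pkrtz_f16_f32 instruction, which converts two floats to a packed pair of halves with round-toward-zero. A rough equivalent written with the generic HIP half intrinsics is sketched below; this is an illustrative alternative form, not the code path CK selects.

#include <hip/hip_fp16.h>

// Illustrative sketch: pack two floats into a half2 using round-toward-zero
// conversions, mirroring what the hunk above does via the AMDGCN builtin.
__device__ __half2 float2_to_half2_rtz(float lo, float hi)
{
    return __halves2half2(__float2half_rz(lo), __float2half_rz(hi));
}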
@@ -28,7 +28,8 @@ using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 // static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-// static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+// static constexpr auto GemmMNPadding =
+//     ck::tensor_operation::device::GemmSpecialization::MNPadding;
 static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple<
...