"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6b45b1d6b4dff4e805d42c096c4db4e6d7be141e"
Unverified Commit 888317e6 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix universal gemm profiler for pk_i4_t (#1790)

* Fix universal gemm profiler for pk_i4_t

* fix
parent 37b35146
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) ...@@ -44,10 +44,19 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
else else
os << delim; os << delim;
if constexpr(std::is_same_v<T, ck::f8_t> || std::is_same_v<T, ck::bf8_t>) using RangeType = ck::remove_cvref_t<decltype(v)>;
if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
std::is_same_v<RangeType, ck::bhalf_t>)
{ {
os << ck::type_convert<float>(v); os << ck::type_convert<float>(v);
} }
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
<< vector_of_floats.template AsType<float>()[ck::Number<1>{}];
}
else else
{ {
os << static_cast<T>(v); os << static_cast<T>(v);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_ ...@@ -465,6 +465,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif #endif
} }
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f32 = ck::type_convert<float>(x_l);
auto h_f32 = ck::type_convert<float>(x_h);
return {l_f32, h_f32};
}
template <> template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x) inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -177,7 +177,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -177,7 +177,7 @@ bool profile_gemm_universal_impl(int do_verification,
} }
} }
if(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>) if constexpr(is_same_v<BDataType, pk_i4_t> && is_same_v<ADataType, half_t>)
{ {
// vector pk_i4x4 permute // vector pk_i4x4 permute
for(int i = 0; i < N; i++) for(int i = 0; i < N; i++)
...@@ -188,7 +188,7 @@ bool profile_gemm_universal_impl(int do_verification, ...@@ -188,7 +188,7 @@ bool profile_gemm_universal_impl(int do_verification,
for(int k = 0; k < 4; k++) for(int k = 0; k < 4; k++)
{ {
int i4x2 = b_k_n_permute(j + k * 2, i); int i4x2 = b_k_n_permute(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf; input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf; input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment