Unverified Commit 9ee69dd2 authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Fix pk_int4 cast and add pk_int4 dtype in ck tile (#1854)

* Fix pk_int4 cast and add pk_int4 dtype in ck tile

* fixes

* Improvements

* fix typo
parent 9c5b2f39
......@@ -163,6 +163,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
......
......@@ -16,7 +16,8 @@ namespace ck {
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q)
// Convert lower part of packed int4 -> int4 to half
__device__ inline half4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
......@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
return res.template AsType<half4_t>()[Number<0>{}];
}
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
......@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
return res.template AsType<half4_t>()[Number<0>{}];
}
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
vector_type<half_t, 2> res;
half_t x_h = (x_u8 & 0x0f) - 8;
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
res.template AsType<half_t>()(Number<0>{}) = x_l;
res.template AsType<half_t>()(Number<1>{}) = x_h;
return res.template AsType<half2_t>()[Number<0>{}];
#endif
}
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
......@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
vector_type<bhalf_t, 2> res;
res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
return res.template AsType<bhalf2_t>()[Number<0>{}];
}
namespace tensor_operation {
namespace element_wise {
......@@ -159,11 +118,11 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
......@@ -171,13 +130,13 @@ struct PassThroughPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
......@@ -185,11 +144,11 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
......@@ -197,13 +156,13 @@ struct PassThroughPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<bhalf2_t>()(Number<0>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<bhalf2_t>()(Number<1>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<bhalf2_t>()(Number<2>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<bhalf2_t>()(Number<3>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
......@@ -219,12 +178,12 @@ struct DequantPack8
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<1>{}) =
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
......@@ -232,13 +191,13 @@ struct DequantPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
......@@ -260,7 +219,7 @@ struct PassThroughPack2
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
......
......@@ -7,6 +7,8 @@
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/array.hpp"
#include "ck/utility/amd_inline_asm.hpp"
#include "ck/utility/type.hpp"
namespace ck {
// Define the common macro for MI300 models
......@@ -14,6 +16,26 @@ namespace ck {
#define __gfx94__
#endif
namespace {
namespace details {
[[maybe_unused]] __host__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
half2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
[[maybe_unused]] __device__ half2_t pk_add_f16(const half2_t& x, const half2_t& y)
{
return amd_assembly_pk_add_f16(x, y);
}
} // namespace details
} // namespace
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
......@@ -520,13 +542,51 @@ template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
auto l_f32 = ck::type_convert<float>(x_l);
auto h_f32 = ck::type_convert<float>(x_h);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
float2_t res = {x_h, x_l};
#elif
float2_t res = {x_l, x_h};
#endif
return res;
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#else
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return details::pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
}
template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, pk_i4_t>(pk_i4_t x)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
bhalf2_t res = {type_convert<bhalf_t>(x_h), type_convert<bhalf_t>(x_l)};
#else
bhalf2_t res = {type_convert<bhalf_t>(x_l), type_convert<bhalf_t>(x_h)};
#endif
return {l_f32, h_f32};
return res;
}
template <>
......
......@@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
......
......@@ -144,6 +144,10 @@
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1
#endif
#ifndef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
#define CK_TILE_USE_PK4_LAYOUT_SHUFFLE 1
#endif
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#include "ck_tile/core/numeric/int8.hpp"
#pragma once
namespace ck_tile {
// Packed 2xint4
struct pk_int4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_int4_t() : data{type{}} {}
__host__ __device__ constexpr pk_int4_t(type init) : data{init} {}
};
// limits
template <class T>
struct numeric;
template <>
struct numeric<pk_int4_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t lowest()
{
constexpr uint8_t val = 0b10001000;
return pk_int4_t(bit_cast<int8_t>(val));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
{
constexpr uint8_t val = 0b01110111;
return pk_int4_t(bit_cast<int8_t>(val));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr pk_int4_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr pk_int4_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr pk_int4_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
};
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
fp32x2_t res = {x_h, x_l};
#elif
fp32x2_t res = {x_l, x_h};
#endif
return res;
}
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
#elif
uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
#endif
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
}
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t& x)
{
uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
#elif
bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
#endif
return res;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -200,4 +200,21 @@ using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif
__host__ fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t vector_res;
vector_res.x = x.x + y.x;
vector_res.y = x.y + y.y;
return vector_res;
}
__device__ fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
fp16x2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(x), "v"(y));
return c;
}
} // namespace ck_tile
......@@ -2,3 +2,4 @@ add_subdirectory(image_to_column)
add_subdirectory(gemm)
add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(data_type)
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include <hip/hip_runtime.h>
#include "ck_tile/core.hpp"
using ck_tile::bf16_t;
using ck_tile::bf16x2_t;
using ck_tile::fp16x2_t;
using ck_tile::fp32x2_t;
using ck_tile::half_t;
using ck_tile::pk_int4_t;
TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
constexpr float first_input_val = 7.f;
constexpr float second_input_val = -1.f;
#else
constexpr float first_input_val = -1.f;
constexpr float second_input_val = 7.f;
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
fp32x2_t out = ck_tile::pk_int4_t_to_fp32x2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
const half_t first_input_val = ck_tile::type_convert<half_t>(7.f);
const half_t second_input_val = ck_tile::type_convert<half_t>(-1.f);
#else
const half_t first_input_val = ck_tile::type_convert<half_t>(-1.f);
const half_t second_input_val = ck_tile::type_convert<half_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
fp16x2_t out = ck_tile::pk_int4_t_to_halfx2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(7.f);
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(-1.f);
#else
const bf16_t first_input_val = ck_tile::type_convert<bf16_t>(-1.f);
const bf16_t second_input_val = ck_tile::type_convert<bf16_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_int4_t in = ck_tile::bit_cast<int8_t>(data);
bf16x2_t out = ck_tile::pk_int4_t_to_bfloat16x2_t(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
......@@ -50,3 +50,4 @@ endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
add_gtest_executable(test_bhalf test_bhalf.cpp)
add_gtest_executable(test_pk_i4 test_pk_i4.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <bitset>
#include <cinttypes>
#include <cstdint>
#include <iomanip>
#include "gtest/gtest.h"
#include <hip/hip_runtime.h>
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::float2_t;
using ck::half2_t;
using ck::half4_t;
using ck::half_t;
using ck::pk_i4_t;
using ck::pk_i4x4_t;
TEST(PackedInt4, ConvertToFloat)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
constexpr float first_input_val = 7.f;
constexpr float second_input_val = -1.f;
#else
constexpr float first_input_val = -1.f;
constexpr float second_input_val = 7.f;
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
float2_t out = ck::type_convert<float2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
constexpr half_t first_input_val = ck::type_convert<half_t>(7.f);
constexpr half_t second_input_val = ck::type_convert<half_t>(-1.f);
#else
constexpr half_t first_input_val = ck::type_convert<half_t>(-1.f);
constexpr half_t second_input_val = ck::type_convert<half_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
half2_t out = ck::type_convert<half2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
TEST(PackedInt4, ConvertToBHalf)
{
#ifdef CK_USE_PK4_LAYOUT_SHUFFLE
const bhalf_t first_input_val = ck::type_convert<bhalf_t>(7.f);
const bhalf_t second_input_val = ck::type_convert<bhalf_t>(-1.f);
#else
const bhalf_t first_input_val = ck::type_convert<bhalf_t>(-1.f);
const bhalf_t second_input_val = ck::type_convert<bhalf_t>(7.f);
#endif
uint8_t data = 0b11110111; // {-1, 7}
pk_i4_t in = ck::bit_cast<int8_t>(data);
bhalf2_t out = ck::type_convert<bhalf2_t>(in);
EXPECT_EQ(out.x, first_input_val);
EXPECT_EQ(out.y, second_input_val);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment