Commit 9051a9f9 authored by Adam Osewski's avatar Adam Osewski
Browse files

Add UT with static_cast

parent 0cf8b552
...@@ -68,6 +68,19 @@ __global__ void copy(const int4_t* src, std::int8_t* dst, ck::index_t N) ...@@ -68,6 +68,19 @@ __global__ void copy(const int4_t* src, std::int8_t* dst, ck::index_t N)
} }
} }
__global__ void copy_with_static_cast(const int4_t* src, std::int8_t* dst, ck::index_t N)
{
ck::index_t tid = ck::get_thread_global_1d_id();
if(tid < N)
{
for(ck::index_t i = tid; i < N; i += ck::get_grid_size())
{
dst[i] = static_cast<std::int8_t>(src[i]);
}
}
}
} // anonymous namespace } // anonymous namespace
TEST(Int4, CopyAsI8PositiveValue) TEST(Int4, CopyAsI8PositiveValue)
...@@ -124,6 +137,33 @@ TEST(Int4, DISABLED_CopyAsI8NegativeValue) ...@@ -124,6 +137,33 @@ TEST(Int4, DISABLED_CopyAsI8NegativeValue)
} }
} }
TEST(Int4, CopyAsI8NegativeValueStaticCast)
{
constexpr std::size_t SIZE = 32;
std::vector<int4_t> h_src_i4(SIZE, -8);
std::vector<std::int8_t> h_src_i8(SIZE, -8);
std::vector<std::int8_t> h_dst_i8(SIZE, 0);
DeviceMem d_src_i4(h_src_i4.size() * sizeof(int4_t));
DeviceMem d_dst_i8(h_dst_i8.size() * sizeof(std::int8_t));
d_src_i4.SetZero();
d_dst_i8.SetZero();
d_src_i4.ToDevice(h_src_i4.data());
copy_with_static_cast<<<1, 64>>>(reinterpret_cast<const int4_t*>(d_src_i4.GetDeviceBuffer()),
reinterpret_cast<std::int8_t*>(d_dst_i8.GetDeviceBuffer()),
SIZE);
hip_check_error(hipDeviceSynchronize());
d_dst_i8.FromDevice(h_dst_i8.data());
for(std::size_t i = 0; i < SIZE; ++i)
{
EXPECT_EQ(h_src_i8[i], h_dst_i8[i]);
}
}
TEST(Int4, DISABLED_BitwiseRepresentation) TEST(Int4, DISABLED_BitwiseRepresentation)
{ {
using bit8_t = std::bitset<8>; using bit8_t = std::bitset<8>;
...@@ -145,3 +185,25 @@ TEST(Int4, DISABLED_BitwiseRepresentation) ...@@ -145,3 +185,25 @@ TEST(Int4, DISABLED_BitwiseRepresentation)
EXPECT_EQ(bit8_t{static_cast<std::uint64_t>(a_i8)}, bit8_t{static_cast<std::uint64_t>(b_i8)}); EXPECT_EQ(bit8_t{static_cast<std::uint64_t>(a_i8)}, bit8_t{static_cast<std::uint64_t>(b_i8)});
} }
TEST(Int4, BitwiseRepresentationStaticCast)
{
using bit8_t = std::bitset<8>;
int4_t a_i4{3};
std::int8_t a_i8 = static_cast<std::int8_t>(a_i4);
std::int8_t b_i8{3};
// std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
// << ", " << static_cast<int32_t>(b_i8) << std::endl;
EXPECT_EQ(bit8_t{static_cast<std::uint64_t>(a_i8)}, bit8_t{static_cast<std::uint64_t>(b_i8)});
a_i4 = int4_t{-3};
a_i8 = static_cast<std::int8_t>(a_i4);
b_i8 = std::int8_t{-3};
// std::cout << std::hex << std::showbase << static_cast<int32_t>(a_i8)
// << ", " << static_cast<int32_t>(b_i8) << std::endl;
EXPECT_EQ(bit8_t{static_cast<std::uint64_t>(a_i8)}, bit8_t{static_cast<std::uint64_t>(b_i8)});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment