Unverified Commit f4af5aed authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski Committed by GitHub
Browse files

Handle type conversions to a const datatype (#944)

* Handle type conversions to a const datatype

* Review: Handle X being const data type as well

* Review: Remove typo
parent e2243a4d
...@@ -9,8 +9,10 @@ ...@@ -9,8 +9,10 @@
namespace ck { namespace ck {
// Convert X to Y // Convert X to Y, both X and Y are non-const data types.
template <typename Y, typename X> template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
......
...@@ -13,3 +13,5 @@ add_gtest_executable(test_bf8 bf8.cpp) ...@@ -13,3 +13,5 @@ add_gtest_executable(test_bf8 bf8.cpp)
if(result EQUAL 0) if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility) target_link_libraries(test_bf8 PRIVATE utility)
endif() endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf_t;
using ck::type_convert;
TEST(TypeConvertConst, ConvertToConst)
{
constexpr float bf16_epsilon = 0.0078125;
constexpr float rel_tol = 2 * bf16_epsilon;
const std::vector<float> cases = {0.0, -123.f, 3.981323f, 0.2429f};
for(float x : cases)
{
const float abs_tol = std::abs(rel_tol * x);
{
bhalf_t y = type_convert<bhalf_t>(x);
// Test non-const bhalf to const float.
const float y_float = type_convert<const float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
// Test non-const float to const bhalf.
const bhalf_t y = type_convert<const bhalf_t>(x);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
float y_float = type_convert<float>(y_nonconst);
ASSERT_NEAR(y_float, x, abs_tol);
}
}
}
TEST(TypeConvertConst, ConvertFromConst)
{
constexpr float bf16_epsilon = 0.0078125;
constexpr float rel_tol = 2 * bf16_epsilon;
const std::vector<float> cases = {0.0, -123.f, 3.981323f, 0.2429f};
for(const float x : cases)
{
const float abs_tol = std::abs(rel_tol * x);
{
// Test const float to const bhalf_t.
const bhalf_t y = type_convert<const bhalf_t>(x);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
float y_float = type_convert<float>(y_nonconst);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
// Test const float to non-const bhalf.
bhalf_t y = type_convert<bhalf_t>(x);
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
const bhalf_t y = type_convert<const bhalf_t>(x);
// Test const bhalf to non-const float.
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
// Tests with full type specializations for X.
{
// Test const float to const bhalf_t.
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
float y_float = type_convert<float>(y_nonconst);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
// Test const float to non-const bhalf.
bhalf_t y = type_convert<bhalf_t, const float>(x);
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Test const bhalf to non-const float.
float y_float = type_convert<float, const bhalf_t>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment