Commit 7d0eede1 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Handle X being const data type as well

parent c5ed30fa
...@@ -9,8 +9,10 @@ ...@@ -9,8 +9,10 @@
namespace ck { namespace ck {
// Convert X to Y, Y is a non-const data type. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, typename X, std::enable_if_t<!std::is_const_v<Y>, bool> = false> template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
...@@ -18,14 +20,17 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -18,14 +20,17 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// Convert X to Y, Y is a const data type. // Convert X to Y, either X or Y is a const data type..
template <typename Y, typename X, std::enable_if_t<std::is_const_v<Y>, bool> = false> template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>; using NonConstY = std::remove_const_t<Y>;
return static_cast<Y>(type_convert<NonConstY>(x)); using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
} }
// convert bfp16 to fp32 // convert bfp16 to fp32
......
...@@ -15,6 +15,3 @@ if(result EQUAL 0) ...@@ -15,6 +15,3 @@ if(result EQUAL 0)
endif() endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_type_convert_const type_convert_const.cpp)
if(result EQUAL 0)
target_link_libraries(test_type_convert_const PRIVATE utility)
endif()
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
...@@ -67,5 +67,27 @@ TEST(TypeConvertConst, ConvertFromConst) ...@@ -67,5 +67,27 @@ TEST(TypeConvertConst, ConvertFromConst)
float y_float = type_convert<float>(y); float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol); ASSERT_NEAR(y_float, x, abs_tol);
} }
// Tests with full type specializations for X.
{
// Test const float to const bhalf_t.
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
float y_float = type_convert<float>(y_nonconst);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
// Test const float to non-const bhalf.
bhalf_t y = type_convert<bhalf_t, const float>(x);
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Test const bhalf to non-const float.
float y_float = type_convert<float, const bhalf_t>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment