Commit 7d0eede1 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Handle X being const data type as well

parent c5ed30fa
......@@ -9,8 +9,10 @@
namespace ck {
// Convert X to Y, Y is a non-const data type.
template <typename Y, typename X, std::enable_if_t<!std::is_const_v<Y>, bool> = false>
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
......@@ -18,14 +20,17 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x);
}
// Convert X to Y, Y is a const data type.
template <typename Y, typename X, std::enable_if_t<std::is_const_v<Y>, bool> = false>
// Convert X to Y, either X or Y is a const data type..
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
return static_cast<Y>(type_convert<NonConstY>(x));
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
// convert bfp16 to fp32
......
......@@ -15,6 +15,3 @@ if(result EQUAL 0)
endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
if(result EQUAL 0)
target_link_libraries(test_type_convert_const PRIVATE utility)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
......@@ -67,5 +67,27 @@ TEST(TypeConvertConst, ConvertFromConst)
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
// Tests with full type specializations for X.
{
// Test const float to const bhalf_t.
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
float y_float = type_convert<float>(y_nonconst);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
// Test const float to non-const bhalf.
bhalf_t y = type_convert<bhalf_t, const float>(x);
float y_float = type_convert<float>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
{
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
// Test const bhalf to non-const float.
float y_float = type_convert<float, const bhalf_t>(y);
ASSERT_NEAR(y_float, x, abs_tol);
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment