Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c5ed30fa
Commit
c5ed30fa
authored
Sep 26, 2023
by
Bartlomiej Wroblewski
Browse files
Handle type conversions to a const datatype
parent
c9553832
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
88 additions
and
2 deletions
+88
-2
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+12
-2
test/data_type/CMakeLists.txt
test/data_type/CMakeLists.txt
+5
-0
test/data_type/type_convert_const.cpp
test/data_type/type_convert_const.cpp
+71
-0
No files found.
include/ck/utility/type_convert.hpp
View file @
c5ed30fa
...
...
@@ -9,8 +9,8 @@
namespace
ck
{
// Convert X to Y
template
<
typename
Y
,
typename
X
>
// Convert X to Y
, Y is a non-const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<!
std
::
is_const_v
<
Y
>,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
...
...
@@ -18,6 +18,16 @@ __host__ __device__ constexpr Y type_convert(X x)
return
static_cast
<
Y
>
(
x
);
}
// Convert X to Y, Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
>
(
x
));
}
// convert bfp16 to fp32
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
...
...
test/data_type/CMakeLists.txt
View file @
c5ed30fa
...
...
@@ -13,3 +13,8 @@ add_gtest_executable(test_bf8 bf8.cpp)
if
(
result EQUAL 0
)
target_link_libraries
(
test_bf8 PRIVATE utility
)
endif
()
add_gtest_executable
(
test_type_convert_const type_convert_const.cpp
)
if
(
result EQUAL 0
)
target_link_libraries
(
test_type_convert_const PRIVATE utility
)
endif
()
test/data_type/type_convert_const.cpp
0 → 100644
View file @
c5ed30fa
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using
ck
::
bhalf_t
;
using
ck
::
type_convert
;
TEST
(
TypeConvertConst
,
ConvertToConst
)
{
constexpr
float
bf16_epsilon
=
0.0078125
;
constexpr
float
rel_tol
=
2
*
bf16_epsilon
;
const
std
::
vector
<
float
>
cases
=
{
0.0
,
-
123.
f
,
3.981323
f
,
0.2429
f
};
for
(
float
x
:
cases
)
{
const
float
abs_tol
=
std
::
abs
(
rel_tol
*
x
);
{
bhalf_t
y
=
type_convert
<
bhalf_t
>
(
x
);
// Test non-const bhalf to const float.
const
float
y_float
=
type_convert
<
const
float
>
(
y
);
ASSERT_NEAR
(
y_float
,
x
,
abs_tol
);
}
{
// Test non-const float to const bhalf.
const
bhalf_t
y
=
type_convert
<
const
bhalf_t
>
(
x
);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t
&
y_nonconst
=
const_cast
<
bhalf_t
&>
(
y
);
float
y_float
=
type_convert
<
float
>
(
y_nonconst
);
ASSERT_NEAR
(
y_float
,
x
,
abs_tol
);
}
}
}
TEST
(
TypeConvertConst
,
ConvertFromConst
)
{
constexpr
float
bf16_epsilon
=
0.0078125
;
constexpr
float
rel_tol
=
2
*
bf16_epsilon
;
const
std
::
vector
<
float
>
cases
=
{
0.0
,
-
123.
f
,
3.981323
f
,
0.2429
f
};
for
(
const
float
x
:
cases
)
{
const
float
abs_tol
=
std
::
abs
(
rel_tol
*
x
);
{
// Test const float to const bhalf_t.
const
bhalf_t
y
=
type_convert
<
const
bhalf_t
>
(
x
);
// Remove the constness manually to not rely on const casts anymore since the
// possible issue could hide after two casts.
bhalf_t
&
y_nonconst
=
const_cast
<
bhalf_t
&>
(
y
);
float
y_float
=
type_convert
<
float
>
(
y_nonconst
);
ASSERT_NEAR
(
y_float
,
x
,
abs_tol
);
}
{
// Test const float to non-const bhalf.
bhalf_t
y
=
type_convert
<
bhalf_t
>
(
x
);
float
y_float
=
type_convert
<
float
>
(
y
);
ASSERT_NEAR
(
y_float
,
x
,
abs_tol
);
}
{
const
bhalf_t
y
=
type_convert
<
const
bhalf_t
>
(
x
);
// Test const bhalf to non-const float.
float
y_float
=
type_convert
<
float
>
(
y
);
ASSERT_NEAR
(
y_float
,
x
,
abs_tol
);
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment