Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
24961297
Commit
24961297
authored
Sep 24, 2024
by
Astha Rai
Browse files
updating multiple utility files to deal with standard header inclusion for hiprtc
parent
9cdd9165
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
58 additions
and
34 deletions
+58
-34
include/ck/utility/array.hpp
include/ck/utility/array.hpp
+1
-1
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+2
-2
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+1
-0
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+17
-0
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+2
-2
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+5
-5
include/ck/utility/integral_constant.hpp
include/ck/utility/integral_constant.hpp
+5
-0
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+2
-2
include/ck/utility/random_gen.hpp
include/ck/utility/random_gen.hpp
+7
-6
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+7
-7
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+2
-2
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+7
-7
No files found.
include/ck/utility/array.hpp
View file @
24961297
...
@@ -54,7 +54,7 @@ template <typename X, typename... Xs>
...
@@ -54,7 +54,7 @@ template <typename X, typename... Xs>
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
__host__
__device__
constexpr
auto
make_array
(
X
&&
x
,
Xs
&&
...
xs
)
{
{
using
data_type
=
remove_cvref_t
<
X
>
;
using
data_type
=
remove_cvref_t
<
X
>
;
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Xs
>
(
xs
)...};
return
Array
<
data_type
,
sizeof
...(
Xs
)
+
1
>
{
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Xs
>
(
xs
)...};
}
}
// make empty array
// make empty array
...
...
include/ck/utility/container_helper.hpp
View file @
24961297
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__
__device__
constexpr
auto
container_concat
(
const
Array
<
T
,
NX
>&
ax
,
const
Array
<
T
,
NY
>&
ay
)
__host__
__device__
constexpr
auto
container_concat
(
const
Array
<
T
,
NX
>&
ax
,
const
Array
<
T
,
NY
>&
ay
)
{
{
return
unpack2
(
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
}
}
template
<
typename
...
X
,
typename
...
Y
>
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
container_concat
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
__host__
__device__
constexpr
auto
container_concat
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
{
return
unpack2
(
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
}
}
template
<
typename
Container
>
template
<
typename
Container
>
...
...
include/ck/utility/debug.hpp
View file @
24961297
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#ifndef UTILITY_DEBUG_HPP
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace
ck
{
namespace
ck
{
namespace
debug
{
namespace
debug
{
...
...
include/ck/utility/enable_if.hpp
View file @
24961297
...
@@ -5,10 +5,27 @@
...
@@ -5,10 +5,27 @@
namespace
ck
{
namespace
ck
{
#ifndef CK_CODE_GEN_RTC
template
<
bool
B
,
typename
T
=
void
>
template
<
bool
B
,
typename
T
=
void
>
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
template
<
bool
B
,
typename
T
=
void
>
template
<
bool
B
,
typename
T
=
void
>
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
#else
template
<
bool
B
,
class
T
=
void
>
struct
enable_if
{
};
template
<
class
T
>
struct
enable_if
<
true
,
T
>
{
using
type
=
T
;
};
template
<
bool
B
,
class
T
=
void
>
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
#endif
}
// namespace ck
}
// namespace ck
include/ck/utility/functional.hpp
View file @
24961297
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
{
if
constexpr
(
predicate
)
if
constexpr
(
predicate
)
{
{
return
std
::
forward
<
X
>
(
x
);
return
ck
::
forward
<
X
>
(
x
);
}
}
else
else
{
{
return
std
::
forward
<
Y
>
(
y
);
return
ck
::
forward
<
Y
>
(
y
);
}
}
}
}
...
...
include/ck/utility/functional4.hpp
View file @
24961297
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template
<
typename
F
,
typename
X
>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
)
const
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
)
const
{
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
}
}
};
};
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template
<
typename
F
,
typename
X
,
typename
Y
>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
,
Y
&&
y
)
const
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
,
Y
&&
y
)
const
{
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
std
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
ck
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
}
}
};
};
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
{
using
X_
=
remove_reference_t
<
X
>
;
using
X_
=
remove_reference_t
<
X
>
;
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
));
}
}
// TODO: properly implement unpack that takes any number of containers
// TODO: properly implement unpack that takes any number of containers
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using
Y_
=
remove_reference_t
<
Y
>
;
using
Y_
=
remove_reference_t
<
Y
>
;
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Y
>
(
y
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Y
>
(
y
));
}
}
}
// namespace ck
}
// namespace ck
...
...
include/ck/utility/integral_constant.hpp
View file @
24961297
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
}
}
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
true_type
=
bool_constant
<
true
>
;
using
false_type
=
bool_constant
<
false
>
;
}
// namespace ck
}
// namespace ck
include/ck/utility/math_v2.hpp
View file @
24961297
...
@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
...
@@ -19,7 +19,7 @@ extern "C" __device__ float __ocml_native_recip_f32(float);
#endif
#endif
// math functions for the host, some are implemented by calling C++ std functions
// math functions for the host, some are implemented by calling C++ std functions
#ifndef CK_CODE_GEN_RTC
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
static
inline
__host__
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
static
inline
__host__
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
...
@@ -457,7 +457,7 @@ inline __host__ double expm1<double>(double x)
...
@@ -457,7 +457,7 @@ inline __host__ double expm1<double>(double x)
{
{
return
std
::
expm1
(
x
);
return
std
::
expm1
(
x
);
}
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
...
...
include/ck/utility/random_gen.hpp
View file @
24961297
...
@@ -2,12 +2,13 @@
...
@@ -2,12 +2,13 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <ck/utility/ignore.hpp>
namespace
ck
{
namespace
ck
{
// Pseudo random number generator
// Pseudo random number generator
// version for fp32
// version for fp32
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
uint32_t
x
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
val
));
uint32_t
x
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
val
));
...
@@ -23,7 +24,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -23,7 +24,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
}
// version for fp16
// version for fp16
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
...
@@ -40,12 +41,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
...
@@ -40,12 +41,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
// return 0 if data is not fp16 or fp32
template
<
typename
T
,
template
<
typename
T
,
uint32_t
seed_t
,
uint32_t
seed_t
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
{
std
::
ignore
=
id
;
ck
::
ignore
=
id
;
std
::
ignore
=
val
;
ck
::
ignore
=
val
;
std
::
ignore
=
seed
;
ck
::
ignore
=
seed
;
return
0
;
return
0
;
}
}
...
...
include/ck/utility/tuple.hpp
View file @
24961297
...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
template
<
typename
T
,
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
ck
::
forward
<
T
>
(
v
))
{
{
}
}
...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template
<
typename
Key
,
typename
Data
>
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
__host__
__device__
constexpr
Data
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
{
{
return
std
::
forward
(
x
.
mData
);
return
ck
::
forward
(
x
.
mData
);
}
}
template
<
typename
Indices
,
typename
...
Xs
>
template
<
typename
Indices
,
typename
...
Xs
>
...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Y
>
(
y
))...
{
{
}
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Ys
>
(
ys
))...
{
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
"wrong! inconsistent size"
);
...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
Y
,
template
<
typename
Y
,
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_cvref_t
<
Y
>,
Tuple
>::
value
,
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_cvref_t
<
Y
>,
Tuple
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
std
::
forward
<
Y
>
(
y
))
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
ck
::
forward
<
Y
>
(
y
))
{
{
}
}
template
<
typename
...
Ys
,
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
typename
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
false
>
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
ck
::
forward
<
Ys
>
(
ys
)...)
{
{
}
}
...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template
<
typename
...
Xs
>
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
{
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
ck
::
forward
<
Xs
>
(
xs
)...);
}
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
// https://en.cppreference.com/w/cpp/utility/tuple/tie
...
...
include/ck/utility/tuple_helper.hpp
View file @
24961297
...
@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
...
@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const
Tuple
<
Y
&
...
>&
ty
)
const
Tuple
<
Y
&
...
>&
ty
)
{
{
return
unpack2
(
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
tx
,
ty
);
ty
);
}
}
...
@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
...
@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
{
return
unpack2
(
return
unpack2
(
[
&
](
auto
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
[
&
](
auto
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
tx
,
ty
);
ty
);
}
}
...
...
include/ck/utility/type_convert.hpp
View file @
24961297
...
@@ -20,7 +20,7 @@ template <typename Y,
...
@@ -20,7 +20,7 @@ template <typename Y,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
return
static_cast
<
Y
>
(
x
);
}
}
...
@@ -31,7 +31,7 @@ template <typename Y,
...
@@ -31,7 +31,7 @@ template <typename Y,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
...
@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
...
@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
{
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
return
static_cast
<
Y
>
(
x
);
}
}
...
@@ -166,7 +166,7 @@ template <>
...
@@ -166,7 +166,7 @@ template <>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
{
{
...
@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
...
@@ -218,7 +218,7 @@ template <>
...
@@ -218,7 +218,7 @@ template <>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
{
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
#if defined(__gfx94__)
union
union
{
{
...
@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
...
@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
return
utils
::
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
x
,
rng
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment