Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
24608d43
Commit
24608d43
authored
Sep 25, 2024
by
Mirza Halilcevic
Browse files
Merge branch 'ck_mgx_temp' into ck_migraphx_integration
parents
a4fe62ed
eaeb3dac
Changes
49
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
220 additions
and
30 deletions
+220
-30
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+5
-1
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+1
-1
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+20
-0
include/ck/utility/random_gen.hpp
include/ck/utility/random_gen.hpp
+9
-3
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+4
-0
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+7
-7
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+4
-2
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+157
-5
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+13
-11
No files found.
include/ck/utility/loop_scheduler.hpp
View file @
24608d43
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ostream>
#pragma once
#ifndef __HIPCC_RTC__
#include <ostream>
#endif
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
...
...
@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
}
// namespace ck
#ifndef __HIPCC_RTC__
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
LoopScheduler
&
s
)
{
switch
(
s
)
...
...
@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
}
return
os
;
}
#endif
include/ck/utility/magic_division.hpp
View file @
24608d43
...
...
@@ -30,7 +30,7 @@ struct MagicDivision
// WARNING: magic division is only applicable for division inside this range.
// You should use the return value of CalculateMagicNumbers, if division is not inside this
// range. The "else" logic below is to quiet down run-time error.
if
(
divisor
>=
1
&&
divisor
<=
INT32_MAX
)
if
(
divisor
>=
1
&&
divisor
<=
ck
::
NumericLimits
<
int32_t
>::
Max
()
)
{
uint32_t
shift
=
0
;
for
(
shift
=
0
;
shift
<
32
;
++
shift
)
...
...
include/ck/utility/math_v2.hpp
View file @
24608d43
...
...
@@ -18,6 +18,7 @@ namespace math {
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
...
...
@@ -457,6 +458,7 @@ inline __host__ double expm1<double>(double x)
{
return
std
::
expm1
(
x
);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
...
...
@@ -920,5 +922,23 @@ inline __device__ double expm1<double>(double x)
return
expm1
(
x
);
};
// Generic device-side cosine.
// The argument is converted to float, the cosine is evaluated with cosf, and
// the float result is converted back to the caller's type T.
template <typename T>
inline __device__ T cos(T x)
{
    const float x_as_float   = ck::type_convert<float>(x);
    const float cos_as_float = cosf(x_as_float);
    return ck::type_convert<T>(cos_as_float);
};
// Explicit specialization of cos for float.
// Forwards the argument directly to cosf, so no type_convert round trip is
// performed for this type.
template
<
>
inline
__device__
float
cos
<
float
>
(
float
x
)
{
return
cosf
(
x
);
};
// Explicit specialization of cos for double.
// The call below is deliberately qualified with `::`. An unqualified `cos(x)`
// written inside this namespace finds the enclosing `cos` template first, and
// for a double argument that selects this very specialization, so the
// function would call itself without end. `::cos` names the global-namespace
// double-precision cosine instead.
template <>
inline __device__ double cos<double>(double x)
{
    return ::cos(x);
};
}
// namespace math
}
// namespace ck
include/ck/utility/random_gen.hpp
View file @
24608d43
...
...
@@ -7,7 +7,7 @@ namespace ck {
// Pseudo random number generator
// version for fp32
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
ck
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
uint32_t
x
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
val
));
...
...
@@ -23,7 +23,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
ck
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
...
...
@@ -40,12 +40,18 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
ck
::
is_same
<
float
,
T
>{}
||
ck
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
#ifdef __HIPCC_RTC__
static_cast
<
void
>
(
id
);
static_cast
<
void
>
(
val
);
static_cast
<
void
>
(
seed
);
#else
std
::
ignore
=
id
;
std
::
ignore
=
val
;
std
::
ignore
=
seed
;
#endif
return
0
;
}
...
...
include/ck/utility/sequence.hpp
View file @
24608d43
...
...
@@ -3,7 +3,9 @@
#pragma once
#ifndef __HIPCC_RTC__
#include <ostream>
#endif
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
...
...
@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
}
// namespace ck
#ifndef __HIPCC_RTC__
template
<
ck
::
index_t
...
Is
>
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
Sequence
<
Is
...
>
)
{
...
...
@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
os
<<
S
::
At
(
S
::
Size
()
-
ck
::
Number
<
1
>
{}).
value
<<
"}"
;
return
os
;
}
#endif
include/ck/utility/tuple.hpp
View file @
24608d43
...
...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
ck
::
forward
<
T
>
(
v
))
{
}
...
...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
{
return
std
::
forward
(
x
.
mData
);
return
ck
::
forward
(
x
.
mData
);
}
template
<
typename
Indices
,
typename
...
Xs
>
...
...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Y
>
(
y
))...
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Ys
>
(
ys
))...
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
Y
,
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_cvref_t
<
Y
>,
Tuple
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
std
::
forward
<
Y
>
(
y
))
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
ck
::
forward
<
Y
>
(
y
))
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
ck
::
forward
<
Ys
>
(
ys
)...)
{
}
...
...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
ck
::
forward
<
Xs
>
(
xs
)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
...
...
include/ck/utility/tuple_helper.hpp
View file @
24608d43
...
...
@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const
Tuple
<
Y
&
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
ty
);
}
...
...
@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
__host__
__device__
constexpr
auto
concat_tuple
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
[
&
](
auto
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
ty
);
}
...
...
@@ -157,6 +157,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
}
}
#ifndef __HIPCC_RTC__
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
...
...
@@ -165,6 +166,7 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
return
(
is_detected
<
is_tuple
,
Ts
>::
value
||
...);
}
#endif
template
<
index_t
depth
=
0
,
typename
T
>
__host__
__device__
constexpr
auto
TupleDepth
(
const
T
&
)
...
...
include/ck/utility/type.hpp
View file @
24608d43
...
...
@@ -8,6 +8,158 @@
#include "ck/utility/enable_if.hpp"
namespace
ck
{
#ifdef __HIPCC_RTC__
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
true_type
=
bool_constant
<
true
>
;
using
false_type
=
bool_constant
<
false
>
;
// Helper macros that declare a type trait `name` as a struct deriving from
// bool_constant<__name(...)>, where `__name` is produced by pasting `__` in
// front of the macro argument (e.g. is_class -> __is_class).
// NOTE(review): presumably `__name` is a compiler builtin type trait; confirm
// that every trait instantiated below is supported by the target compiler.
//
// Unary form: name<T>.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// Binary form: name<T, U>.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// Variadic form: name<Ts...>. No use of this macro is visible in this section.
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
// Unary traits generated through the macro above.
CK_BUILTIN_TYPE_TRAIT1
(
is_class
);
CK_BUILTIN_TYPE_TRAIT1
(
is_pointer
);
CK_BUILTIN_TYPE_TRAIT1
(
is_reference
);
CK_BUILTIN_TYPE_TRAIT1
(
is_trivially_copyable
);
CK_BUILTIN_TYPE_TRAIT1
(
is_unsigned
);
// Binary trait generated through the two-parameter macro.
CK_BUILTIN_TYPE_TRAIT2
(
is_base_of
);
// remove_cv<T>::type is T with any top-level const and/or volatile removed.
//
// Primary template: T carries no top-level cv-qualifier.
template <class T>
struct remove_cv
{
    using type = T;
};

// Top-level const only.
template <class T>
struct remove_cv<const T>
{
    using type = T;
};

// Top-level volatile only.
template <class T>
struct remove_cv<volatile T>
{
    using type = T;
};

// Both qualifiers at once. This specialization is required: without it a
// `const volatile` type matches the `const T` and the `volatile T`
// specializations equally well and the instantiation is rejected as ambiguous.
template <class T>
struct remove_cv<const volatile T>
{
    using type = T;
};
// remove_reference<T>::type is the type T refers to; non-reference types are
// returned unchanged.
template <class Plain>
struct remove_reference
{
    using type = Plain;
};

// Lvalue reference: strip the `&`.
template <class Referred>
struct remove_reference<Referred&>
{
    using type = Referred;
};

// Rvalue reference: strip the `&&`.
template <class Referred>
struct remove_reference<Referred&&>
{
    using type = Referred;
};
// remove_pointer<T>::type is the pointee type when T is a (possibly
// cv-qualified) pointer, and T itself otherwise. Only one level of pointer is
// removed; qualifiers on the pointee are kept.
template <class NonPointer>
struct remove_pointer
{
    using type = NonPointer;
};

// Unqualified pointer.
template <class Pointee>
struct remove_pointer<Pointee*>
{
    using type = Pointee;
};

// const-qualified pointer.
template <class Pointee>
struct remove_pointer<Pointee* const>
{
    using type = Pointee;
};

// volatile-qualified pointer.
template <class Pointee>
struct remove_pointer<Pointee* volatile>
{
    using type = Pointee;
};

// const volatile-qualified pointer.
template <class Pointee>
struct remove_pointer<Pointee* const volatile>
{
    using type = Pointee;
};
// forward<T>(t_) casts its argument to T&&. T must be given explicitly: the
// parameter type is spelled through remove_reference<T>::type, so it cannot
// be deduced from the argument.
//
// Overload taking an lvalue of the referred-to type.
template
<
typename
T
>
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&
t_
)
noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
// Overload taking an rvalue of the referred-to type; performs the same cast.
template
<
typename
T
>
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&&
t_
)
noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
// TODO
// is_const<T>: derives from false_type in the general case.
template
<
class
T
>
struct
is_const
:
false_type
{};
// Partial specialization for top-level const types: derives from true_type.
template
<
class
T
>
struct
is_const
<
const
T
>
:
true_type
{};
// Variable-template shorthand for is_const<T>::value.
template
<
class
T
>
inline
constexpr
bool
is_const_v
=
is_const
<
T
>::
value
;
template
<
class
T
>
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
// remove_const<T>::type is T without its top-level const qualifier; types
// that are not const-qualified are returned unchanged.
template <class Unqualified>
struct remove_const
{
    using type = Unqualified;
};

// Top-level const: strip it.
template <class Unqualified>
struct remove_const<const Unqualified>
{
    using type = Unqualified;
};

// Alias shorthand for remove_const<T>::type.
template <class T>
using remove_const_t = typename remove_const<T>::type;
template
<
class
T
>
inline
constexpr
bool
is_class_v
=
is_class
<
T
>::
value
;
template
<
class
T
>
inline
constexpr
bool
is_trivially_copyable_v
=
is_trivially_copyable
<
T
>::
value
;
template
<
class
...
>
using
void_t
=
void
;
using
__hip
::
declval
;
#else
#include <utility>
#include <type_traits>
using
std
::
forward
;
using
std
::
is_base_of
;
using
std
::
is_class
;
using
std
::
is_pointer
;
using
std
::
is_reference
;
using
std
::
is_trivially_copyable
;
using
std
::
is_unsigned
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
using
std
::
is_const_v
;
using
std
::
is_reference_v
;
using
std
::
remove_const_t
;
using
std
::
is_class_v
;
using
std
::
is_trivially_copyable_v
;
using
std
::
void_t
;
using
std
::
false_type
;
using
std
::
true_type
;
using
std
::
declval
;
#endif
template
<
typename
X
,
typename
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
...
...
@@ -23,19 +175,19 @@ template <typename X, typename Y>
inline
constexpr
bool
is_same_v
=
is_same
<
X
,
Y
>::
value
;
template
<
typename
T
>
using
remove_reference_t
=
typename
std
::
remove_reference
<
T
>::
type
;
using
remove_reference_t
=
typename
remove_reference
<
T
>::
type
;
template
<
typename
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
using
remove_cv_t
=
typename
remove_cv
<
T
>::
type
;
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
using
remove_cvref_t
=
remove_cv_t
<
remove_reference_t
<
T
>>
;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
using
remove_pointer_t
=
typename
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
inline
constexpr
bool
is_pointer_v
=
is_pointer
<
T
>::
value
;
template
<
typename
Y
,
typename
X
,
typename
enable_if
<
sizeof
(
X
)
==
sizeof
(
Y
),
bool
>
::
type
=
false
>
__host__
__device__
constexpr
Y
bit_cast
(
const
X
&
x
)
...
...
include/ck/utility/type_convert.hpp
View file @
24608d43
...
...
@@ -17,10 +17,10 @@ namespace ck {
// Convert X to Y, both X and Y are non-const data types.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
...
...
@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template
<
typename
Y
,
typename
X
,
std
::
enable_if_t
<
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
,
bool
>
=
false
>
ck
::
enable_if_t
<
ck
::
is_const_v
<
Y
>
||
ck
::
is_const_v
<
X
>
,
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
using
NonConstY
=
std
::
remove_const_t
<
Y
>
;
using
NonConstX
=
std
::
remove_const_t
<
X
>
;
using
NonConstY
=
ck
::
remove_const_t
<
Y
>
;
using
NonConstX
=
ck
::
remove_const_t
<
X
>
;
return
static_cast
<
Y
>
(
type_convert
<
NonConstY
,
NonConstX
>
(
x
));
}
...
...
@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert_sp
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
...
...
@@ -166,7 +166,7 @@ template <>
inline
__host__
__device__
f8_t
f8_convert_sr
<
f8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
long_index
_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
union
{
...
...
@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
long_index
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
f8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
@@ -218,7 +218,7 @@ template <>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
long_index
_t
>
(
&
x
),
x
);
#if defined(__gfx94__)
union
{
...
...
@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
1254739
;
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
long_index
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
@@ -501,6 +501,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}
#ifndef __HIPCC_RTC__
template
<
typename
Y
,
typename
X
,
std
::
size_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
std
::
array
<
Y
,
NumElems
>&
y
,
const
std
::
array
<
X
,
NumElems
>&
x
)
...
...
@@ -510,6 +511,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
y
[
i
]
=
type_convert
<
Y
>
(
x
[
i
]);
}
}
#endif
template
<
typename
Y
,
typename
X
,
index_t
NumElems
>
inline
__host__
__device__
void
array_convert
(
Array
<
Y
,
NumElems
>&
y
,
const
Array
<
X
,
NumElems
>&
x
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment