Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0d8e489b
Commit
0d8e489b
authored
Jul 12, 2023
by
Rostyslav Geyyer
Browse files
Format
parent
74d97e51
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
48 additions
and
38 deletions
+48
-38
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+33
-25
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+1
-4
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+12
-8
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+2
-1
No files found.
include/ck/utility/data_type.hpp
View file @
0d8e489b
...
@@ -205,7 +205,8 @@ struct vector_type<T, 1>
...
@@ -205,7 +205,8 @@ struct vector_type<T, 1>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -257,7 +258,8 @@ struct vector_type<T, 2>
...
@@ -257,7 +258,8 @@ struct vector_type<T, 2>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -321,7 +323,8 @@ struct vector_type<T, 4>
...
@@ -321,7 +323,8 @@ struct vector_type<T, 4>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -397,7 +400,8 @@ struct vector_type<T, 8>
...
@@ -397,7 +400,8 @@ struct vector_type<T, 8>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -485,7 +489,8 @@ struct vector_type<T, 16>
...
@@ -485,7 +489,8 @@ struct vector_type<T, 16>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -583,7 +588,8 @@ struct vector_type<T, 32>
...
@@ -583,7 +588,8 @@ struct vector_type<T, 32>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -693,7 +699,8 @@ struct vector_type<T, 64>
...
@@ -693,7 +699,8 @@ struct vector_type<T, 64>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -813,7 +820,8 @@ struct vector_type<T, 128>
...
@@ -813,7 +820,8 @@ struct vector_type<T, 128>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -1051,28 +1059,28 @@ struct NumericLimits<f8_t>
...
@@ -1051,28 +1059,28 @@ struct NumericLimits<f8_t>
__host__
__device__
static
f8_t
Min
()
__host__
__device__
static
f8_t
Min
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_min
;
x
.
data
=
binary_min
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
Max
()
__host__
__device__
static
f8_t
Max
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_max
;
x
.
data
=
binary_max
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
Lowest
()
__host__
__device__
static
f8_t
Lowest
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_lowest
;
x
.
data
=
binary_lowest
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
QuietNaN
()
__host__
__device__
static
f8_t
QuietNaN
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_qnan
;
x
.
data
=
binary_qnan
;
return
x
;
return
x
;
}
}
};
};
...
...
include/ck/utility/f8_utils.hpp
View file @
0d8e489b
...
@@ -250,7 +250,4 @@ __host__ __device__ T cast_from_f8(uint8_t x)
...
@@ -250,7 +250,4 @@ __host__ __device__ T cast_from_f8(uint8_t x)
}
// namespace ck::utils
}
// namespace ck::utils
// f8_t constuctor impl
// f8_t constuctor impl
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
{
data
=
init
;
}
{
data
=
init
;
}
include/ck/utility/type_convert.hpp
View file @
0d8e489b
...
@@ -106,7 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...
@@ -106,7 +106,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
x
,
rng
));
}
}
...
@@ -126,7 +127,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -126,7 +127,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
x
,
rng
));
}
}
...
@@ -209,7 +211,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...
@@ -209,7 +211,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
x
,
rng
));
}
}
...
@@ -223,7 +226,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -223,7 +226,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
x
,
rng
));
}
}
...
...
library/include/ck/library/utility/check_err.hpp
View file @
0d8e489b
...
@@ -216,7 +216,8 @@ check_err(const Range& out,
...
@@ -216,7 +216,8 @@ check_err(const Range& out,
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
),
bool
>
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
),
bool
>
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment