Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0d8e489b
"...composable_kernel.git" did not exist on "7e8e54dead53f981b7e562c1eed6e80cb8702aac"
Commit
0d8e489b
authored
Jul 12, 2023
by
Rostyslav Geyyer
Browse files
Format
parent
74d97e51
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
48 additions
and
38 deletions
+48
-38
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+33
-25
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+1
-4
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+12
-8
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+2
-1
No files found.
include/ck/utility/data_type.hpp
View file @
0d8e489b
...
@@ -13,12 +13,12 @@ using half_t = _Float16;
...
@@ -13,12 +13,12 @@ using half_t = _Float16;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
#endif
#endif
struct
f8_t
struct
f8_t
{
{
uint8_t
data
;
uint8_t
data
;
using
type
=
f8_t
;
using
type
=
f8_t
;
using
data_type
=
uint8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
f8_t
()
=
default
;
__host__
__device__
f8_t
()
=
default
;
__host__
__device__
f8_t
(
uint8_t
init
);
__host__
__device__
f8_t
(
uint8_t
init
);
};
};
...
@@ -26,7 +26,7 @@ struct f8_t
...
@@ -26,7 +26,7 @@ struct f8_t
struct
bf8_t
struct
bf8_t
{
{
uint8_t
data
;
uint8_t
data
;
using
type
=
bf8_t
;
using
type
=
bf8_t
;
using
data_type
=
uint8_t
;
using
data_type
=
uint8_t
;
__host__
__device__
bf8_t
()
=
default
;
__host__
__device__
bf8_t
()
=
default
;
...
@@ -205,7 +205,8 @@ struct vector_type<T, 1>
...
@@ -205,7 +205,8 @@ struct vector_type<T, 1>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
2
>
struct
vector_type
<
T
,
2
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -257,7 +258,8 @@ struct vector_type<T, 2>
...
@@ -257,7 +258,8 @@ struct vector_type<T, 2>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
4
>
struct
vector_type
<
T
,
4
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -321,7 +323,8 @@ struct vector_type<T, 4>
...
@@ -321,7 +323,8 @@ struct vector_type<T, 4>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
8
>
struct
vector_type
<
T
,
8
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -397,7 +400,8 @@ struct vector_type<T, 8>
...
@@ -397,7 +400,8 @@ struct vector_type<T, 8>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
16
>
struct
vector_type
<
T
,
16
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -485,7 +489,8 @@ struct vector_type<T, 16>
...
@@ -485,7 +489,8 @@ struct vector_type<T, 16>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
32
>
struct
vector_type
<
T
,
32
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -583,7 +588,8 @@ struct vector_type<T, 32>
...
@@ -583,7 +588,8 @@ struct vector_type<T, 32>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
64
>
struct
vector_type
<
T
,
64
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -693,7 +699,8 @@ struct vector_type<T, 64>
...
@@ -693,7 +699,8 @@ struct vector_type<T, 64>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
128
>
struct
vector_type
<
T
,
128
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -813,7 +820,8 @@ struct vector_type<T, 128>
...
@@ -813,7 +820,8 @@ struct vector_type<T, 128>
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
256
>
struct
vector_type
<
T
,
256
>
{
{
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
T_adjusted
=
typename
std
::
conditional
<
std
::
is_same
<
T
,
f8_t
>::
value
,
f8_t
::
data_type
,
T
>::
type
;
using
d1_t
=
T_adjusted
;
using
d1_t
=
T_adjusted
;
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
typedef
T_adjusted
d2_t
__attribute__
((
ext_vector_type
(
2
)));
...
@@ -1047,33 +1055,33 @@ struct NumericLimits<f8_t>
...
@@ -1047,33 +1055,33 @@ struct NumericLimits<f8_t>
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_max
=
0x77
;
// 0b01110111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_lowest
=
0xF7
;
// 0b11110111
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
static
constexpr
uint8_t
binary_qnan
=
0x80
;
// 0b10000000
__host__
__device__
static
f8_t
Min
()
__host__
__device__
static
f8_t
Min
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_min
;
x
.
data
=
binary_min
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
Max
()
__host__
__device__
static
f8_t
Max
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_max
;
x
.
data
=
binary_max
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
Lowest
()
__host__
__device__
static
f8_t
Lowest
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_lowest
;
x
.
data
=
binary_lowest
;
return
x
;
return
x
;
}
}
__host__
__device__
static
f8_t
QuietNaN
()
__host__
__device__
static
f8_t
QuietNaN
()
{
{
f8_t
x
;
f8_t
x
;
x
.
data
=
binary_qnan
;
x
.
data
=
binary_qnan
;
return
x
;
return
x
;
}
}
};
};
...
...
include/ck/utility/f8_utils.hpp
View file @
0d8e489b
...
@@ -250,7 +250,4 @@ __host__ __device__ T cast_from_f8(uint8_t x)
...
@@ -250,7 +250,4 @@ __host__ __device__ T cast_from_f8(uint8_t x)
}
// namespace ck::utils
}
// namespace ck::utils
// f8_t constructor impl
// f8_t constructor impl
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
inline
__host__
__device__
ck
::
f8_t
::
f8_t
(
uint8_t
init
)
{
data
=
init
;
}
{
data
=
init
;
}
include/ck/utility/type_convert.hpp
View file @
0d8e489b
...
@@ -106,8 +106,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...
@@ -106,8 +106,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
x
,
rng
));
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
}
}
// convert fp8 to fp32
// convert fp8 to fp32
...
@@ -126,8 +127,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -126,8 +127,9 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
x
,
rng
));
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
}
}
// convert fp8 to fp16
// convert fp8 to fp16
...
@@ -209,8 +211,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...
@@ -209,8 +211,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
f8_t
(
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
x
,
rng
));
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
}
}
// convert fp16 to fp8 with stochastic rounding
// convert fp16 to fp8 with stochastic rounding
...
@@ -223,8 +226,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -223,8 +226,9 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
f8_t
(
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
return
f8_t
(
x
,
rng
));
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
));
}
}
}
// namespace ck
}
// namespace ck
library/include/ck/library/utility/check_err.hpp
View file @
0d8e489b
...
@@ -216,7 +216,8 @@ check_err(const Range& out,
...
@@ -216,7 +216,8 @@ check_err(const Range& out,
template
<
typename
Range
,
typename
RefRange
>
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
),
bool
>
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f8_t
>
),
bool
>
check_err
(
const
Range
&
out
,
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment