Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
562ec121
Commit
562ec121
authored
Jun 09, 2023
by
Rostyslav Geyyer
Browse files
Rearrange f8_utils' namespaces
parent
c208a8ac
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
20 deletions
+30
-20
include/ck/utility/f8_utils.hpp
include/ck/utility/f8_utils.hpp
+20
-12
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+10
-8
No files found.
include/ck/utility/f8_utils.hpp
View file @
562ec121
...
@@ -17,6 +17,12 @@ enum class f8_rounding_mode
...
@@ -17,6 +17,12 @@ enum class f8_rounding_mode
stochastic
stochastic
};
};
}
// namespace ck
namespace
ck
::
utils
{
namespace
{
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
__host__
__device__
f8_t
run_cast_to_f8
(
T
x
,
uint32_t
rng
)
{
{
...
@@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
...
@@ -127,17 +133,6 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
return
(
sign
<<
(
f8_exp
+
f8_mant
))
|
(
exponent
<<
f8_mant
)
|
mantissa
;
return
(
sign
<<
(
f8_exp
+
f8_mant
))
|
(
exponent
<<
f8_mant
)
|
mantissa
;
}
}
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted to f8."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
__host__
__device__
T
run_cast_from_f8
(
f8_t
x
)
{
{
...
@@ -225,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
...
@@ -225,6 +220,19 @@ __host__ __device__ T run_cast_from_f8(f8_t x)
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
return
*
(
reinterpret_cast
<
const
T
*>
(
&
retval
));
}
}
}
// namespace
template
<
typename
T
,
bool
negative_zero_nan
,
bool
clip
,
bool
stoch
>
__host__
__device__
f8_t
cast_to_f8
(
T
x
,
uint32_t
rng
)
{
// check datatype
constexpr
bool
is_half
=
std
::
is_same
<
T
,
half_t
>::
value
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"Only half and float can be casted to f8."
);
return
run_cast_to_f8
<
T
,
negative_zero_nan
,
clip
,
stoch
>
(
x
,
rng
);
}
template
<
typename
T
,
bool
negative_zero_nan
>
template
<
typename
T
,
bool
negative_zero_nan
>
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
__host__
__device__
T
cast_from_f8
(
f8_t
x
)
{
{
...
@@ -240,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x)
...
@@ -240,4 +248,4 @@ __host__ __device__ T cast_from_f8(f8_t x)
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
return
run_cast_from_f8
<
T
,
negative_zero_nan
>
(
x
);
}
}
}
// namespace ck
}
// namespace ck
::utils
include/ck/utility/type_convert.hpp
View file @
562ec121
...
@@ -111,7 +111,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...
@@ -111,7 +111,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
}
// convert fp8 to fp32
// convert fp8 to fp32
...
@@ -119,7 +120,7 @@ template <>
...
@@ -119,7 +120,7 @@ template <>
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
inline
__host__
__device__
float
type_convert
<
float
,
f8_t
>
(
f8_t
x
)
{
{
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
cast_from_f8
<
float
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
float
,
negative_zero_nan
>
(
x
);
}
}
// convert fp16 to fp8
// convert fp16 to fp8
...
@@ -130,8 +131,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
...
@@ -130,8 +131,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
constexpr
bool
clip
=
true
;
constexpr
bool
clip
=
true
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
standard
;
constexpr
uint32_t
rng
=
0
;
constexpr
uint32_t
rng
=
0
;
return
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
}
}
// convert fp8 to fp16
// convert fp8 to fp16
...
@@ -139,7 +140,7 @@ template <>
...
@@ -139,7 +140,7 @@ template <>
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
inline
__host__
__device__
half_t
type_convert
<
half_t
,
f8_t
>
(
f8_t
x
)
{
{
constexpr
bool
negative_zero_nan
=
true
;
constexpr
bool
negative_zero_nan
=
true
;
return
cast_from_f8
<
half_t
,
negative_zero_nan
>
(
x
);
return
utils
::
cast_from_f8
<
half_t
,
negative_zero_nan
>
(
x
);
}
}
// Declare a template function for bf16 conversion using RTN
// Declare a template function for bf16 conversion using RTN
...
@@ -213,7 +214,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...
@@ -213,7 +214,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
return
utils
::
cast_to_f8
<
float
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
}
}
// convert fp16 to fp8 with stochastic rounding
// convert fp16 to fp8 with stochastic rounding
...
@@ -226,8 +228,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
...
@@ -226,8 +228,8 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr
int
seed
=
42
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr_t
>
(
&
x
),
x
);
return
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
return
utils
::
cast_to_f8
<
half_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
rng
);
x
,
rng
);
}
}
}
// namespace ck
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment