Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4ddb62bd
Commit
4ddb62bd
authored
May 12, 2023
by
Rostyslav Geyyer
Browse files
Add fp8_convert_sr
parent
4089bc68
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
3 deletions
+28
-3
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+28
-3
No files found.
include/ck/utility/data_type.hpp
View file @
4ddb62bd
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/get_id.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -1130,7 +1131,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
...
@@ -1130,7 +1131,8 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template
<
typename
T
,
uint32_t
seed
,
std
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed
,
std
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
)
{
{
uint32_t
x
=
reinterpret_cast
<
uint32_t
&>
(
val
);
// uint32_t x = reinterpret_cast<uint32_t&>(val);
uint32_t
x
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
val
));
uint32_t
drop_bits
=
uint32_t
(
x
)
&
0xFFFFu
;
uint32_t
drop_bits
=
uint32_t
(
x
)
&
0xFFFFu
;
drop_bits
^=
x
>>
16
;
drop_bits
^=
x
>>
16
;
drop_bits
=
((
drop_bits
&
31
)
<<
11
)
|
(
drop_bits
>>
5
);
drop_bits
=
((
drop_bits
&
31
)
<<
11
)
|
(
drop_bits
>>
5
);
...
@@ -1146,7 +1148,8 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val)
...
@@ -1146,7 +1148,8 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val)
template
<
typename
T
,
uint32_t
seed
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
)
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
)
{
{
uint16_t
x
=
reinterpret_cast
<
uint16_t
&>
(
val
);
// uint16_t x = reinterpret_cast<uint16_t&>(val);
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
uint32_t
drop_bits
=
uint32_t
(
x
)
&
0xFFFFu
;
uint32_t
drop_bits
=
uint32_t
(
x
)
&
0xFFFFu
;
drop_bits
=
((
drop_bits
&
31
)
<<
11
)
|
(
drop_bits
>>
5
);
drop_bits
=
((
drop_bits
&
31
)
<<
11
)
|
(
drop_bits
>>
5
);
drop_bits
*=
0x7000149
;
drop_bits
*=
0x7000149
;
...
@@ -1164,12 +1167,34 @@ template <typename T,
...
@@ -1164,12 +1167,34 @@ template <typename T,
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
)
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
)
{
{
std
::
ignore
=
id
;
std
::
ignore
=
id
;
std
::
ignore
=
seed
;
std
::
ignore
=
val
;
std
::
ignore
=
val
;
return
0
;
return
0
;
}
}
// Primary template for fp8 conversion with stochastic rounding (SR).
// Declaration only: Y is the destination type, X is the source type.
// No generic body is given here; definitions are expected to be supplied
// by explicit specializations for the (Y, X) pairs that are supported.
template <typename Y, typename X>
__host__ __device__ constexpr Y fp8_convert_sr(X x);
// Specialization: fp32 -> fp8 using stochastic rounding.
//
// A 32-bit pseudo-random word is obtained from prand_generator, keyed on the
// calling thread's global 1-D id and on the value being converted, and is
// handed to cast_to_f8 together with the input.
template <>
inline __host__ __device__ f8_t fp8_convert_sr<f8_t, float>(float x)
{
    // Fixed compile-time seed supplied to the pseudo-random generator.
    constexpr int seed = 42;

    // Compile-time switches forwarded to cast_to_f8; their exact semantics
    // are defined by cast_to_f8 itself (see ck/utility/f8_utils.hpp).
    constexpr bool nan_as_negative_zero = true;
    constexpr bool clip_result          = true;
    constexpr f8_rounding_mode rounding = f8_rounding_mode::stochastic;
    constexpr bool use_stochastic       = (rounding == f8_rounding_mode::stochastic);

    uint32_t random_bits = prand_generator<float, seed>(get_thread_global_1d_id(), x);

    return cast_to_f8<nan_as_negative_zero, clip_result, use_stochastic>(x, random_bits);
}
// Specialization: fp8 -> fp32.
// Simply forwards to type_convert<float>; no random number is generated
// on this path.
template <>
inline __host__ __device__ float fp8_convert_sr<float, f8_t>(f8_t x)
{
    const float widened = type_convert<float>(x);
    return widened;
}
template
<
typename
T
>
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment