Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
807a4818
Commit
807a4818
authored
Oct 21, 2024
by
Andriy Roshchenko
Browse files
Add constexpr where applicable.
parent
b1a7d2a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
20 deletions
+21
-20
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+21
-20
No files found.
include/ck/utility/amd_ck_fp8.hpp
View file @
807a4818
...
@@ -111,7 +111,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -111,7 +111,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
fmax
=
bit_cast
<
_Float16
>
(
ifmax
);
fmax
=
bit_cast
<
_Float16
>
(
ifmax
);
fmin
=
bit_cast
<
_Float16
>
(
ifmin
);
fmin
=
bit_cast
<
_Float16
>
(
ifmin
);
}
}
else
if
(
is_float
)
else
if
constexpr
(
is_float
)
{
{
const
unsigned
int
ifInf
=
0x7F800000
;
const
unsigned
int
ifInf
=
0x7F800000
;
const
unsigned
int
ifNegInf
=
0xFF800000
;
const
unsigned
int
ifNegInf
=
0xFF800000
;
...
@@ -128,7 +128,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -128,7 +128,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
fmax
=
bit_cast
<
float
>
(
ifmax
);
fmax
=
bit_cast
<
float
>
(
ifmax
);
fmin
=
bit_cast
<
float
>
(
ifmin
);
fmin
=
bit_cast
<
float
>
(
ifmin
);
}
}
else
if
(
is_double
)
else
if
constexpr
(
is_double
)
{
{
const
unsigned
long
long
ifInf
=
0x7FF0000000000000ull
;
const
unsigned
long
long
ifInf
=
0x7FF0000000000000ull
;
const
unsigned
long
long
ifNegInf
=
0xFFF0000000000000ull
;
const
unsigned
long
long
ifNegInf
=
0xFFF0000000000000ull
;
...
@@ -167,7 +167,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -167,7 +167,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
{
return
fNeg0
;
return
fNeg0
;
}
}
if
(
we
==
4
)
if
constexpr
(
we
==
4
)
{
// e4m3
{
// e4m3
if
((
x
&
0x7F
)
==
0x7F
)
if
((
x
&
0x7F
)
==
0x7F
)
{
{
...
@@ -178,7 +178,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -178,7 +178,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
// e5m2
{
// e5m2
if
((
x
&
0x3
)
==
0
)
if
((
x
&
0x3
)
==
0
)
{
{
if
(
clip
)
if
constexpr
(
clip
)
{
{
return
sign
?
fmin
:
fmax
;
return
sign
?
fmin
:
fmax
;
}
}
...
@@ -194,7 +194,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -194,7 +194,7 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
retval
;
type
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
{
{
retval
=
x
<<
8
;
retval
=
x
<<
8
;
return
bit_cast
<
T
>
(
retval
);
return
bit_cast
<
T
>
(
retval
);
...
@@ -228,10 +228,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -228,10 +228,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
if
constexpr
(
sizeof
(
T
)
==
2
)
if
constexpr
(
sizeof
(
T
)
==
2
)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
else
if
(
sizeof
(
T
)
==
4
)
else
if
constexpr
(
sizeof
(
T
)
==
4
)
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
else
else
retval
=
(
sign
<<
63
)
|
(
static_cast
<
unsigned
long
long
>
(
exponent
)
<<
52
)
|
mantissa
;
retval
=
(
sign
<<
63
)
|
(
static_cast
<
unsigned
long
long
>
(
exponent
)
<<
52
)
|
mantissa
;
return
bit_cast
<
T
>
(
retval
);
return
bit_cast
<
T
>
(
retval
);
}
}
...
@@ -498,7 +499,7 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
...
@@ -498,7 +499,7 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng =
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
}
}
}
else
if
(
interpret
==
CK_E4M3_OCP
)
else
if
constexpr
(
interpret
==
CK_E4M3_OCP
)
{
// OCP type
{
// OCP type
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
{
/// propagate NAN/INF, no clipping
...
@@ -575,7 +576,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -575,7 +576,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
fInf
=
0x7FF0000000000000ull
;
fInf
=
0x7FF0000000000000ull
;
mask
=
0x7FFFFFFFFFFFFFFFull
;
mask
=
0x7FFFFFFFFFFFFFFFull
;
}
}
else
if
(
sizeof
(
T
)
==
4
)
else
if
constexpr
(
sizeof
(
T
)
==
4
)
{
{
head
=
x
&
0xFF800000
;
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
mantissa
=
x
&
0x7FFFFF
;
...
@@ -604,7 +605,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -604,7 +605,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
}
}
else
else
{
{
if
(
we
==
4
)
if
constexpr
(
we
==
4
)
{
// e4m3
{
// e4m3
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7e
:
0x7f
);
signed_inf
=
(
sign
<<
7
)
+
(
clip
?
0x7e
:
0x7f
);
}
}
...
@@ -618,13 +619,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -618,13 +619,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
unsigned
long
long
ifmax
=
0
;
unsigned
long
long
ifmax
=
0
;
if
constexpr
(
sizeof
(
T
)
==
8
)
if
constexpr
(
sizeof
(
T
)
==
8
)
{
{
if
(
we
==
5
)
if
constexpr
(
we
==
5
)
{
// 57344
{
// 57344
ifmax
=
0x40EC000000000000ull
;
ifmax
=
0x40EC000000000000ull
;
}
}
else
else
{
{
if
(
is_fnuz
)
if
constexpr
(
is_fnuz
)
{
// 240
{
// 240
ifmax
=
0x406E000000000000ull
;
ifmax
=
0x406E000000000000ull
;
}
}
...
@@ -636,13 +637,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -636,13 +637,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
}
}
else
if
(
sizeof
(
T
)
==
4
)
else
if
(
sizeof
(
T
)
==
4
)
{
{
if
(
we
==
5
)
if
constexpr
(
we
==
5
)
{
{
ifmax
=
0x47600000
;
ifmax
=
0x47600000
;
}
}
else
else
{
{
if
(
is_fnuz
)
if
constexpr
(
is_fnuz
)
{
{
ifmax
=
0x43700000
;
ifmax
=
0x43700000
;
}
}
...
@@ -654,13 +655,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -654,13 +655,13 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
}
}
else
else
{
{
if
(
we
==
5
)
if
constexpr
(
we
==
5
)
{
{
ifmax
=
0x7B00
;
ifmax
=
0x7B00
;
}
}
else
else
{
{
if
(
is_fnuz
)
if
constexpr
(
is_fnuz
)
{
{
ifmax
=
0x5B80
;
ifmax
=
0x5B80
;
}
}
...
@@ -673,7 +674,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -673,7 +674,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
// Deal with inf and NaNs
// Deal with inf and NaNs
if
((
x
&
fInf
)
==
fInf
)
if
((
x
&
fInf
)
==
fInf
)
{
{
if
(
is_fnuz
)
if
constexpr
(
is_fnuz
)
return
signed_inf
;
return
signed_inf
;
return
mantissa
!=
0
?
nan
:
signed_inf
;
return
mantissa
!=
0
?
nan
:
signed_inf
;
...
@@ -788,7 +789,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -788,7 +789,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
const
int
max_exp
=
(
1
<<
we
)
-
1
;
const
int
max_exp
=
(
1
<<
we
)
-
1
;
if
(
f8_exponent
>
max_exp
)
if
(
f8_exponent
>
max_exp
)
{
{
if
(
clip
)
if
constexpr
(
clip
)
{
{
mantissa
=
(
1
<<
wm
)
-
1
;
mantissa
=
(
1
<<
wm
)
-
1
;
f8_exponent
=
max_exp
;
f8_exponent
=
max_exp
;
...
@@ -846,15 +847,15 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
...
@@ -846,15 +847,15 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
{
return
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
3
,
4
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E5M2_FNUZ
)
else
if
constexpr
(
interp
==
CK_E5M2_FNUZ
)
{
{
return
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
2
,
5
,
true
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E4M3_OCP
)
else
if
constexpr
(
interp
==
CK_E4M3_OCP
)
{
{
return
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
3
,
4
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
else
if
(
interp
==
CK_E5M2_OCP
)
else
if
constexpr
(
interp
==
CK_E5M2_OCP
)
{
{
return
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
return
cast_to_f8
<
float
,
2
,
5
,
false
,
sat
==
CK_SATFINITE
,
stochastic_rounding
>
(
f
,
rng
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment