Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aabeb268
Unverified
Commit
aabeb268
authored
Feb 25, 2025
by
Gregory Shtrasberg
Committed by
GitHub
Feb 25, 2025
Browse files
[ROCm][Quantization][Kernel] Using HIP FP8 header (#12593)
parent
2f42a488
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
267 additions
and
634 deletions
+267
-634
CMakeLists.txt
CMakeLists.txt
+19
-0
csrc/quantization/fp8/amd/hip_float8.h
csrc/quantization/fp8/amd/hip_float8.h
+0
-137
csrc/quantization/fp8/amd/hip_float8_impl.h
csrc/quantization/fp8/amd/hip_float8_impl.h
+0
-315
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+230
-168
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+5
-3
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+13
-11
No files found.
CMakeLists.txt
View file @
aabeb268
...
@@ -174,6 +174,25 @@ include(FetchContent)
...
@@ -174,6 +174,25 @@ include(FetchContent)
file
(
MAKE_DIRECTORY
${
FETCHCONTENT_BASE_DIR
}
)
# Ensure the directory exists
file
(
MAKE_DIRECTORY
${
FETCHCONTENT_BASE_DIR
}
)
# Ensure the directory exists
message
(
STATUS
"FetchContent base directory:
${
FETCHCONTENT_BASE_DIR
}
"
)
message
(
STATUS
"FetchContent base directory:
${
FETCHCONTENT_BASE_DIR
}
"
)
#
# Set rocm version dev int.
#
if
(
VLLM_GPU_LANG STREQUAL
"HIP"
)
#
# Overriding the default -O set up by cmake, adding ggdb3 for the most verbose devug info
#
set
(
CMAKE_
${
VLLM_GPU_LANG
}
_FLAGS_DEBUG
"
${
CMAKE_
${
VLLM_GPU_LANG
}
_FLAGS_DEBUG
}
-O0 -ggdb3"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-O0 -ggdb3"
)
#
# Certain HIP functions are marked as [[nodiscard]], yet vllm ignores the result which generates
# a lot of warnings that always mask real issues. Suppressing until this is properly addressed.
#
set
(
CMAKE_
${
VLLM_GPU_LANG
}
_FLAGS
"
${
CMAKE_
${
VLLM_GPU_LANG
}
_FLAGS
}
-Wno-unused-result"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wno-unused-result"
)
endif
()
#
#
# Define other extension targets
# Define other extension targets
#
#
...
...
csrc/quantization/fp8/amd/hip_float8.h
deleted
100644 → 0
View file @
2f42a488
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct
alignas
(
1
)
hip_fp8
{
struct
from_bits_t
{};
HIP_FP8_HOST_DEVICE
static
constexpr
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
uint8_t
data
;
hip_fp8
()
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
const
hip_fp8
&
)
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
)
=
delete
;
explicit
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
,
from_bits_t
)
:
data
(
v
)
{}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit
HIP_FP8_DEVICE
hip_fp8
(
float
v
)
:
data
(
hip_fp8_impl
::
to_fp8_from_fp32
(
v
))
{}
explicit
HIP_FP8_DEVICE
hip_fp8
(
_Float16
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
// Host only implementation using s/w simulation
explicit
HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit
HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8
(
float
v
)
{
data
=
hip_fp8_impl
::
to_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
);
}
explicit
HIP_FP8_HOST_DEVICE
hip_fp8
(
double
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit
inline
HIP_FP8_DEVICE
operator
float
()
const
{
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
return
fval
;
}
explicit
inline
HIP_FP8_HOST
operator
float
()
const
#else // __HIP__MI300__
explicit
inline
HIP_FP8_HOST_DEVICE
operator
float
()
const
#endif // __HIP__MI300__
{
return
hip_fp8_impl
::
from_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
};
namespace
std
{
inline
hip_fp8
sin
(
hip_fp8
a
)
{
return
hip_fp8
(
sinf
(
float
(
a
)));
}
inline
hip_fp8
cos
(
hip_fp8
a
)
{
return
hip_fp8
(
cosf
(
float
(
a
)));
}
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
real
(
const
hip_fp8
&
a
)
{
return
a
;
}
}
// namespace std
// Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
hip_fp8
&
f8
)
{
return
os
<<
float
(
f8
);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
const
float
fa
,
hip_fp8
b
)
{
return
(
fa
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
hip_fp8
a
,
const
float
fb
)
{
return
(
float
(
a
)
+
fb
);
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
operator
+
(
hip_fp8
a
,
hip_fp8
b
)
{
return
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
&
operator
+=
(
hip_fp8
&
a
,
hip_fp8
b
)
{
return
a
=
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
// overloading multiplication, always returns float,
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
hip_fp8
b
)
{
return
float
(
a
)
*
float
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
float
a
,
hip_fp8
b
)
{
return
(
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
float
b
)
{
return
(
float
(
a
)
*
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
int32_t
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
double
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
// overloading for compare
inline
HIP_FP8_HOST_DEVICE
bool
operator
==
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
==
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
!=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
!=
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>=
static_cast
<
float
>
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
}
csrc/quantization/fp8/amd/hip_float8_impl.h
deleted
100644 → 0
View file @
2f42a488
#pragma once
#if defined(__HIPCC__) && defined(__gfx942__)
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace
hip_fp8_impl
{
#ifdef __HIP__MI300__
HIP_FP8_DEVICE
uint8_t
to_fp8_from_fp32
(
float
v
)
{
uint8_t
i8data
;
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// NOTE: not endian independent
}
val
;
uint32_t
ival
=
0
;
val
.
fval
=
v
;
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
return
i8data
;
}
#endif // __HIP__MI300__
HIP_FP8_HOST
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
#endif
template
<
int
we
,
int
wm
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
HIP_FP8_HOST_DEVICE
uint8_t
to_float8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
constexpr
bool
is_half
=
false
;
#endif
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
if
(
sizeof
(
T
)
==
4
)
{
x
=
reinterpret_cast
<
uint32_t
&>
(
_x
);
}
else
{
x
=
reinterpret_cast
<
uint16_t
&>
(
_x
);
}
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
uint32_t
sign
;
if
(
sizeof
(
T
)
==
4
)
{
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
sign
=
head
>>
31
;
bias
=
127
;
}
else
{
head
=
x
&
0xFC00
;
mantissa
=
x
&
0x3FF
;
exponent
=
(
head
>>
10
)
&
0x1F
;
sign
=
head
>>
15
;
bias
=
15
;
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
// Deal with inf and NaNs
if
(
negative_zero_nan
)
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
0x80
;
}
}
else
{
// if(__hisinf(x) || __hisnan(x))
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
0x80
;
}
}
}
else
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
else
{
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
}
if
(
x
==
0
)
{
return
0
;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
we
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
f8_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no
// difference for this case, act_exponent could be
// larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
wm
+
exponent_diff
))
-
1
))
==
static_cast
<
uint32_t
>
(
1
<<
(
mfmt
-
wm
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
{
mantissa
>>=
exponent_diff
;
}
else
if
(
exponent_diff
==
-
1
)
{
mantissa
<<=
-
exponent_diff
;
}
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit
// that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
)
{
if
((
1
<<
mfmt
)
&
mantissa
)
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
}
else
{
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
f8_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
wm
);
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
wm
)
-
1
;
f8_exponent
=
max_exp
;
}
else
{
return
signed_inf
;
}
}
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
{
return
negative_zero_nan
?
0
:
(
sign
<<
7
);
}
mantissa
&=
(
1
<<
wm
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
wm
)
|
mantissa
;
}
template
<
int
we
,
int
wm
,
typename
T
=
float
,
bool
negative_zero_nan
=
true
>
inline
HIP_FP8_HOST_DEVICE
T
from_float8
(
uint8_t
x
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
constexpr
bool
is_half
=
false
;
#endif
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
#ifdef __HIPCC__
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
_Float16
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
_Float16
&>
(
ihNeg0
);
}
else
#endif
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
{
return
0
;
}
uint32_t
sign
=
x
>>
7
;
uint32_t
mantissa
=
x
&
((
1
<<
wm
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
wm
;
if
(
negative_zero_nan
)
{
if
(
x
==
0x80
)
{
return
fNaN
;
}
}
else
{
if
(
x
==
0x80
)
{
return
fNeg0
;
}
if
(
exponent
==
((
1
<<
we
)
-
1
))
{
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
wm
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
wm
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
wm
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
wmo
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
if
(
sizeof
(
T
)
==
2
)
{
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
}
else
{
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
}
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
}
// namespace hip_fp8_impl
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
aabeb268
#pragma once
#pragma once
#include
"
hip_f
loat
8.h
"
#include
<hip/
hip_f
p
8.h
>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include <hip/hip_bfloat16.h>
#include "../../../attention/dtype_fp8.cuh"
#include "../../../attention/attention_dtypes.h"
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"
namespace
vllm
{
namespace
vllm
{
#ifdef USE_ROCM
#ifdef USE_ROCM
...
@@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
...
@@ -26,40 +24,31 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return
x
;
return
x
;
}
}
#if HIP_FP8_TYPE_FNUZ
using
fp8_type
=
__hip_fp8_e4m3_fnuz
;
using
fp8x2_type
=
__hip_fp8x2_e4m3_fnuz
;
#elif HIP_FP8_TYPE_OCP
using
fp8_type
=
__hip_fp8_e4m3
;
using
fp8x2_type
=
__hip_fp8x2_e4m3
;
#endif
// fp8 -> half
// fp8 -> half
template
<
>
template
<
>
__inline__
__device__
uint16_t
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
)
{
vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
)
{
hip_fp8
f8
{
a
,
hip_fp8
::
from_bits
()};
return
__hip_cvt_fp8_to_halfraw
(
a
,
fp8_type
::
__default_interpret
).
x
;
__half_raw
res
;
res
.
data
=
static_cast
<
float
>
(
f8
);
return
res
.
x
;
}
}
// fp8x2 -> half2
// fp8x2 -> half2
template
<
>
template
<
>
__inline__
__device__
uint32_t
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const
auto
&
f2
=
__builtin_amdgcn_cvt_pk_f32_fp8
(
a
,
0
);
union
{
union
{
__half2_raw
h2r
;
__half2_raw
h2r
;
uint32_t
ui32
;
uint32_t
ui32
;
}
tmp
;
}
tmp
;
tmp
.
h2r
.
x
.
data
=
f2
[
0
];
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
tmp
.
h2r
.
y
.
data
=
f2
[
1
];
return
tmp
.
ui32
;
return
tmp
.
ui32
;
#else
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
tmp
;
tmp
.
u16
[
0
]
=
vec_conversion
<
uint16_t
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
));
tmp
.
u16
[
1
]
=
vec_conversion
<
uint16_t
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
>>
8U
));
return
tmp
.
u32
;
#endif
}
}
// fp8x4 -> half2x2
// fp8x4 -> half2x2
...
@@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
...
@@ -92,9 +81,9 @@ using __nv_bfloat16 = __hip_bfloat16;
template
<
>
template
<
>
__inline__
__device__
__nv_bfloat16
__inline__
__device__
__nv_bfloat16
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
)
{
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
)
{
hip_fp8
f8
{
a
,
hip_fp8
::
from_bits
()}
;
fp8_type
f8
;
f
loat
f
{
f8
}
;
f
8
.
__x
=
a
;
return
__float2bfloat16
(
f
);
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
)
);
}
}
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat162
=
__hip_bfloat162
;
...
@@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
...
@@ -136,27 +125,18 @@ __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
// fp8 -> float
// fp8 -> float
template
<
>
template
<
>
__inline__
__device__
float
vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
)
{
__inline__
__device__
float
vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
)
{
hip_fp8
fp8
{
a
,
hip_fp8
::
from_bits
()};
fp8_type
f8
;
return
static_cast
<
float
>
(
fp8
);
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
);
}
}
// fp8x2 -> float2
// fp8x2 -> float2
template
<
>
template
<
>
__inline__
__device__
float2
__inline__
__device__
float2
vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
)
{
vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
)
{
#if defined(__HIP__MI300__) && \
fp8x2_type
f8x2
;
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
f8x2
.
__x
=
a
;
float2
res
;
return
static_cast
<
float2
>
(
f8x2
);
const
auto
&
f2
=
__builtin_amdgcn_cvt_pk_f32_fp8
(
a
,
0
);
res
.
x
=
f2
[
0
];
res
.
y
=
f2
[
1
];
return
res
;
#else
float2
res
;
res
.
x
=
vec_conversion
<
float
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
));
res
.
y
=
vec_conversion
<
float
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
>>
8U
));
return
res
;
#endif
}
}
// fp8x4 -> float4
// fp8x4 -> float4
...
@@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
...
@@ -169,6 +149,15 @@ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
return
res
;
return
res
;
}
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
// fp8x8 -> float8
// fp8x8 -> float8
template
<
>
template
<
>
__inline__
__device__
Float8_
vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
)
{
__inline__
__device__
Float8_
vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
)
{
...
@@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
...
@@ -189,33 +178,36 @@ __inline__ __device__ uint8_t
vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__half_raw
tmp
;
__half_raw
tmp
;
tmp
.
x
=
a
;
tmp
.
x
=
a
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
hip_fp8
f8
{
static_cast
<
float
>
(
tmp
.
data
)};
template
<
>
return
f8
.
data
;
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint32_t
ui32
;
__half2_raw
h2r
;
}
tmp
;
tmp
.
ui32
=
a
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
}
// bf16 -> fp8
// bf16 -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
)
{
vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
)
{
hip_fp8
res
{
__bfloat162float
(
a
)};
return
__hip_cvt_float_to_fp8
(
__bfloat162float
(
a
),
return
res
.
data
;
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
}
// float -> fp8
// float -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
)
{
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
)
{
hip_fp8
f8
(
a
);
return
__hip_cvt_float_to_fp8
(
a
,
fp8_type
::
__default_saturation
,
return
f8
.
data
;
fp8_type
::
__default_interpret
);
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
}
// float2 -> half2
// float2 -> half2
...
@@ -307,90 +299,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
...
@@ -307,90 +299,22 @@ vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
*/
*/
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
)
{
hip_fp8
f8
{
a
,
hip_fp8
::
from_bits
()};
__half_raw
res
;
res
.
data
=
static_cast
<
float
>
(
f8
)
*
scale
;
return
res
.
x
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
#if defined(__HIP__MI300__) && \
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
const
auto
&
f2
=
__builtin_amdgcn_cvt_pk_f32_fp8
(
a
,
0
);
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
.
x
.
data
=
f2
[
0
]
*
scale
;
tmp
.
h2r
.
y
.
data
=
f2
[
1
]
*
scale
;
return
tmp
.
ui32
;
#else
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
tmp
;
tmp
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
),
scale
);
tmp
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
>>
8U
),
scale
);
return
tmp
.
u32
;
#endif
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
);
return
tmp
.
u64x2
;
}
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
// fp8 -> __nv_bfloat16
// fp8 -> __nv_bfloat16
template
<
>
template
<
>
__inline__
__device__
__nv_bfloat16
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
const
float
scale
)
{
fp8_type
f8
;
hip_fp8
f8
{
a
,
hip_fp8
::
from_bits
()};
f8
.
__x
=
a
;
float
f
{
f8
};
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
)
*
scale
);
return
__float2bfloat16
(
f
*
scale
);
}
}
using
__nv_bfloat162
=
__hip_bfloat162
;
// fp8x2 -> __nv_bfloat162
// fp8x2 -> __nv_bfloat162
template
<
>
template
<
>
__inline__
__device__
__nv_bfloat162
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
float
scale
)
{
__nv_bfloat162
res
;
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
y
=
res
.
y
=
...
@@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
...
@@ -400,8 +324,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
// fp8x4 -> bf16_4_t
// fp8x4 -> bf16_4_t
template
<
>
template
<
>
__inline__
__device__
bf16_4_t
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
__inline__
__device__
bf16_4_t
const
uint32_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
bf16_4_t
res
;
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
...
@@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
...
@@ -412,7 +336,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
// fp8x8 -> bf16_8_t
// fp8x8 -> bf16_8_t
template
<
>
template
<
>
__inline__
__device__
bf16_8_t
__inline__
__device__
bf16_8_t
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
bf16_4_t
tmp1
,
tmp2
;
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
);
...
@@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
...
@@ -427,29 +351,19 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
// fp8 -> float
// fp8 -> float
template
<
>
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
const
float
scale
)
{
const
uint8_t
&
a
,
float
scale
)
{
hip_fp8
fp8
{
a
,
hip_fp8
::
from_bits
()};
fp8_type
f8
;
return
static_cast
<
float
>
(
fp8
)
*
scale
;
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
)
*
scale
;
}
}
// fp8x2 -> float2
// fp8x2 -> float2
template
<
>
template
<
>
__inline__
__device__
float2
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
#if defined(__HIP__MI300__) && \
fp8x2_type
f8x2
;
defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
f8x2
.
__x
=
a
;
float2
res
;
return
static_cast
<
float2
>
(
f8x2
)
*
scale
;
const
auto
&
f2
=
__builtin_amdgcn_cvt_pk_f32_fp8
(
a
,
0
);
res
.
x
=
f2
[
0
]
*
scale
;
res
.
y
=
f2
[
1
]
*
scale
;
return
res
;
#else
float2
res
;
res
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
),
scale
);
res
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
(
static_cast
<
uint8_t
>
(
a
>>
8U
),
scale
);
return
res
;
#endif
}
}
// fp8x4 -> float4
// fp8x4 -> float4
...
@@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
...
@@ -462,10 +376,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
return
res
;
return
res
;
}
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
);
return
{
res
.
x
.
x
,
res
.
x
.
y
,
res
.
y
.
x
,
res
.
y
.
y
};
}
// fp8x8 -> float8
// fp8x8 -> float8
template
<
>
template
<
>
__inline__
__device__
Float8_
__inline__
__device__
Float8_
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
Float4_
tmp1
,
tmp2
;
Float4_
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
);
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
);
...
@@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
...
@@ -477,44 +399,184 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
return
res
;
return
res
;
}
}
/* Quantize(HP / scale) => FP8 */
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
__half_raw
res
;
res
.
data
=
scaled_vec_conversion
<
float
,
uint8_t
>
(
a
,
scale
);
return
res
.
x
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
__half2_raw
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
tmp
.
h2r
.
x
.
data
*=
scale
;
tmp
.
h2r
.
y
.
data
*=
scale
;
return
tmp
.
ui32
;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
return
tmp
.
u32x2
;
}
// TODO(Hai): vectorized to add
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
);
return
tmp
.
u64x2
;
}
// half -> fp8
// half -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
__half_raw
tmp
;
__half_raw
tmp
;
tmp
.
x
=
a
;
tmp
.
x
=
a
;
tmp
.
data
/=
scale
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// halfx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
union
{
uint32_t
ui32
;
__half2_raw
h2r
;
}
tmp
;
tmp
.
ui32
=
a
;
tmp
.
h2r
.
x
.
data
/=
scale
;
tmp
.
h2r
.
y
.
data
/=
scale
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
hip_fp8
f8
{
static_cast
<
float
>
(
tmp
.
data
)
/
scale
};
// half2x2 -> fp8x4
return
f8
.
data
;
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
);
return
tmp
.
ui32
;
}
// half2x4 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint4
>
(
const
uint4
&
a
,
float
scale
)
{
union
{
uint2
ui2
[
2
];
uint4
ui4
;
}
tmp
;
tmp
.
ui4
=
a
;
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
);
return
res
;
}
}
// bf16 -> fp8
// bf16 -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
const
float
scale
)
{
const
__nv_bfloat16
&
a
,
float
scale
)
{
hip_fp8
res
{
__bfloat162float
(
a
)
/
scale
};
return
__hip_cvt_float_to_fp8
(
__bfloat162float
(
a
)
/
scale
,
return
res
.
data
;
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// bf16x2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
const
__nv_bfloat162
&
a
,
float
scale
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
);
return
tmp
.
ui16
;
}
// bf16x4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
);
return
tmp
.
ui32
;
}
// bf16x8 -> fp8x8
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
)
{
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
);
return
res
;
}
}
// float -> fp8
// float -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
hip_fp8
f8
(
a
/
scale
);
return
__hip_cvt_float_to_fp8
(
a
/
scale
,
fp8_type
::
__default_saturation
,
return
f8
.
data
;
fp8_type
::
__default_interpret
)
;
}
}
// f
p8x4 -> float4
// f
loatx2 -> fp8x2
template
<
>
template
<
>
__inline__
__device__
float4
__inline__
__device__
uint16_t
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
)
{
Float4_
tmp
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
);
return
__hip_cvt_float2_to_fp8x2
(
a
/
scale
,
fp8_type
::
__default_saturation
,
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
fp8_type
::
__default_interpret
);
return
res
;
}
// floatx4 -> fp8x4
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
)
{
union
{
uint16_t
ui16
[
2
];
uint32_t
ui32
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
return
tmp
.
ui32
;
}
}
#endif // ENABLE_FP8
#endif // ENABLE_FP8
...
...
csrc/quantization/fp8/common.cuh
View file @
aabeb268
...
@@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
...
@@ -12,7 +12,7 @@ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/
hip_float8.
h"
#include "amd/
quant_utils.cu
h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
// issue when running dynamic quantization. Here use 224.0f for rocm.
...
@@ -47,7 +47,9 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
...
@@ -47,7 +47,9 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#else
#else
// Use hardware cvt instruction for fp8 on rocm
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
hip_fp8
(
r
).
data
,
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
#endif
}
}
...
...
tests/kernels/test_cache.py
View file @
aabeb268
...
@@ -159,19 +159,20 @@ def test_reshape_and_cache(
...
@@ -159,19 +159,20 @@ def test_reshape_and_cache(
device
)
device
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Using default kv_scale
k_scale
=
(
key
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
64.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
()
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -182,9 +183,9 @@ def test_reshape_and_cache(
...
@@ -182,9 +183,9 @@ def test_reshape_and_cache(
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
k_scale
.
item
()
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
v_scale
.
item
()
)
# Run the reference implementation.
# Run the reference implementation.
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
reshaped_key
=
key
.
reshape
(
num_tokens
,
*
key_cache
[
0
,
:,
:,
0
,
:].
shape
)
...
@@ -268,15 +269,16 @@ def test_reshape_and_cache_flash(
...
@@ -268,15 +269,16 @@ def test_reshape_and_cache_flash(
del
key_caches
del
key_caches
del
value_caches
del
value_caches
k_scale
=
(
key
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
k_scale
=
(
key
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
25
6.0
).
to
(
torch
.
float32
)
v_scale
=
(
value
.
amax
()
/
6
4
.0
).
to
(
torch
.
float32
)
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
,
kv_cache_dtype
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
.
item
(),
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
,
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
.
item
()
,
kv_cache_dtype
)
kv_cache_dtype
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment