Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a81bc31
Commit
7a81bc31
authored
Dec 02, 2025
by
zhuwenwen
Browse files
update fp8 native implementation
parent
98a011e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
402 additions
and
366 deletions
+402
-366
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+400
-364
vllm/utils/__init__.py
vllm/utils/__init__.py
+2
-2
No files found.
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
7a81bc31
...
...
@@ -13,7 +13,41 @@ namespace vllm {
#ifdef USE_ROCM
namespace
fp8
{
// #ifdef ENABLE_FP8
#ifdef ENABLE_FP8
// Use hardware cvt instruction for fp8 on rocm
template
<
typename
fp8_type
>
__device__
__forceinline__
fp8_type
cvt_c10
(
float
const
r
)
{
return
{};
}
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
cvt_c10
(
float
const
r
)
{
#if HIP_FP8_TYPE_OCP
return
c10
::
Float8_e4m3fn
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3
::
__default_saturation
,
__hip_fp8_e4m3
::
__default_interpret
),
c10
::
Float8_e4m3fn
::
from_bits
());
#else
// Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
// HW cvt above is faster when it is available (ROCm 6.3 or newer).
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#endif
}
template
<
>
__device__
__forceinline__
c10
::
Float8_e4m3fnuz
cvt_c10
(
float
const
r
)
{
return
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
r
,
__hip_fp8_e4m3_fnuz
::
__default_saturation
,
__hip_fp8_e4m3_fnuz
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
}
// KV-CACHE int8
static
inline
__device__
float
fp8_to_float
(
uint8_t
input
)
{
...
...
@@ -53,10 +87,11 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return
result
;
}
// template <typename Tout, typename Tin>
// __inline__ __device__ Tout vec_conversion(const Tin& x) {
// return x;
// }
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
vec_conversion
(
const
Tin
&
x
)
{
return
x
;
}
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
...
...
@@ -64,271 +99,271 @@ __inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
return
x
;
}
//
#if HIP_FP8_TYPE_OCP
//
using fp8_type = __hip_fp8_e4m3;
//
using fp8x2_type = __hip_fp8x2_e4m3;
//
#else
//
using fp8_type = __hip_fp8_e4m3_fnuz;
//
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
//
#endif
//
// fp8 -> half
//
template <>
//
__inline__ __device__ uint16_t
//
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
//
return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
//
}
//
// fp8x2 -> half2
//
template <>
//
__inline__ __device__ uint32_t
//
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
//
union {
//
__half2_raw h2r;
//
uint32_t ui32;
//
} tmp;
//
tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
//
return tmp.ui32;
//
}
//
// fp8x4 -> half2x2
//
template <>
//
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
//
union {
//
uint2 u32x2;
//
uint32_t u32[2];
//
} tmp;
//
tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
//
tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
//
return tmp.u32x2;
//
}
//
// fp8x8 -> half2x4
//
template <>
//
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
//
union {
//
uint4 u64x2;
//
uint2 u64[2];
//
} tmp;
//
tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
//
tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
//
return tmp.u64x2;
//
}
//
using __nv_bfloat16 = __hip_bfloat16;
//
// fp8 -> __nv_bfloat16
//
template <>
//
__inline__ __device__ __nv_bfloat16
//
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
//
fp8_type f8;
//
f8.__x = a;
//
return __float2bfloat16(static_cast<float>(f8));
//
}
//
using __nv_bfloat162 = __hip_bfloat162;
//
// fp8x2 -> __nv_bfloat162
//
template <>
//
__inline__ __device__ __nv_bfloat162
//
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
//
__nv_bfloat162 res;
//
res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
//
res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
//
return res;
//
}
//
// fp8x4 -> bf16_4_t
//
template <>
//
__inline__ __device__ bf16_4_t
//
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
//
bf16_4_t res;
//
res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
//
res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
//
return res;
//
}
//
// fp8x8 -> bf16_8_t
//
template <>
//
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
//
bf16_4_t tmp1, tmp2;
//
tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
//
tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
//
bf16_8_t res;
//
res.x = tmp1.x;
//
res.y = tmp1.y;
//
res.z = tmp2.x;
//
res.w = tmp2.y;
//
return res;
//
}
//
// fp8 -> float
//
template <>
//
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
//
fp8_type f8;
//
f8.__x = a;
//
return static_cast<float>(f8);
//
}
//
// fp8x2 -> float2
//
template <>
//
__inline__ __device__ float2
//
vec_conversion<float2, uint16_t>(const uint16_t& a) {
//
fp8x2_type f8x2;
//
f8x2.__x = a;
//
return static_cast<float2>(f8x2);
//
}
//
// fp8x4 -> float4
//
template <>
//
__inline__ __device__ Float4_
//
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
//
Float4_ res;
//
res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
//
res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
//
return res;
//
}
//
// fp8x4 -> float4
//
template <>
//
__inline__ __device__ float4
//
vec_conversion<float4, uint32_t>(const uint32_t& a) {
//
Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
//
float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
//
return res;
//
}
//
// fp8x8 -> float8
//
template <>
//
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
//
Float4_ tmp1, tmp2;
//
tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
//
tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
//
Float8_ res;
//
res.x = tmp1.x;
//
res.y = tmp1.y;
//
res.z = tmp2.x;
//
res.w = tmp2.y;
//
return res;
//
}
//
// half -> fp8
//
template <>
//
__inline__ __device__ uint8_t
//
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
//
__half_raw tmp;
//
tmp.x = a;
//
return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
template <>
//
__inline__ __device__ uint16_t
//
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
//
union {
//
uint32_t ui32;
//
__half2_raw h2r;
//
} tmp;
//
tmp.ui32 = a;
//
return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// bf16 -> fp8
//
template <>
//
__inline__ __device__ uint8_t
//
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
//
return __hip_cvt_float_to_fp8(__bfloat162float(a),
//
fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// float -> fp8
//
template <>
//
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
//
return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
//
fp8_type::__default_interpret);
//
}
//
// float2 -> half2
//
template <>
//
__inline__ __device__ uint32_t
//
vec_conversion<uint32_t, float2>(const float2& a) {
//
union {
//
half2 float16;
//
uint32_t uint32;
//
};
//
float16 = __float22half2_rn(a);
//
return uint32;
//
}
//
// Float4 -> half2x2
//
template <>
//
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
//
uint2 b;
//
float2 val;
//
val.x = a.x.x;
//
val.y = a.x.y;
//
b.x = vec_conversion<uint32_t, float2>(val);
//
val.x = a.y.x;
//
val.y = a.y.y;
//
b.y = vec_conversion<uint32_t, float2>(val);
//
return b;
//
}
//
// Float4 -> float4
//
template <>
//
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
//
float4 b;
//
b.x = a.x.x;
//
b.y = a.x.y;
//
b.z = a.y.x;
//
b.w = a.y.y;
//
return b;
//
}
//
// Float8 -> half2x4
//
template <>
//
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
//
uint4 b;
//
b.x = vec_conversion<uint32_t, float2>(a.x);
//
b.y = vec_conversion<uint32_t, float2>(a.y);
//
b.z = vec_conversion<uint32_t, float2>(a.z);
//
b.w = vec_conversion<uint32_t, float2>(a.w);
//
return b;
//
}
//
// float2 -> bfloat162
//
template <>
//
__inline__ __device__ __nv_bfloat162
//
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
//
__nv_bfloat162 b = __float22bfloat162_rn(a);
//
return b;
//
}
//
// Float4 -> bfloat162x2
//
template <>
//
__inline__ __device__ bf16_4_t
//
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
//
bf16_4_t b;
//
b.x = __float22bfloat162_rn(a.x);
//
b.y = __float22bfloat162_rn(a.y);
//
return b;
//
}
//
// Float8 -> bfloat162x4
//
template <>
//
__inline__ __device__ bf16_8_t
//
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
//
bf16_8_t b;
//
b.x = __float22bfloat162_rn(a.x);
//
b.y = __float22bfloat162_rn(a.y);
//
b.z = __float22bfloat162_rn(a.z);
//
b.w = __float22bfloat162_rn(a.w);
//
return b;
//
}
#if HIP_FP8_TYPE_OCP
using
fp8_type
=
__hip_fp8_e4m3
;
using
fp8x2_type
=
__hip_fp8x2_e4m3
;
#else
using
fp8_type
=
__hip_fp8_e4m3_fnuz
;
using
fp8x2_type
=
__hip_fp8x2_e4m3_fnuz
;
#endif
// fp8 -> half
template
<
>
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
)
{
return
__hip_cvt_fp8_to_halfraw
(
a
,
fp8_type
::
__default_interpret
).
x
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
return
tmp
.
ui32
;
}
// fp8x4 -> half2x2
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint2
u32x2
;
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
);
tmp
.
u32
[
1
]
=
vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
tmp
.
u32x2
;
}
// fp8x8 -> half2x4
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
)
{
union
{
uint4
u64x2
;
uint2
u64
[
2
];
}
tmp
;
tmp
.
u64
[
0
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
);
tmp
.
u64
[
1
]
=
vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
);
return
tmp
.
u64x2
;
}
using
__nv_bfloat16
=
__hip_bfloat16
;
// fp8 -> __nv_bfloat16
template
<
>
__inline__
__device__
__nv_bfloat16
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
));
}
using
__nv_bfloat162
=
__hip_bfloat162
;
// fp8x2 -> __nv_bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__nv_bfloat162
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
));
return
res
;
}
// fp8x4 -> bf16_4_t
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
bf16_4_t
res
;
res
.
x
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x8 -> bf16_8_t
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
)
{
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
);
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// fp8 -> float
template
<
>
__inline__
__device__
float
vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
)
{
fp8_type
f8
;
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
);
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
)
{
fp8x2_type
f8x2
;
f8x2
.
__x
=
a
;
return
static_cast
<
float2
>
(
f8x2
);
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
Float4_
vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
res
;
res
.
x
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
);
res
.
y
=
vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
));
return
res
;
}
// fp8x4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
)
{
Float4_
tmp
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
);
float4
res
=
make_float4
(
tmp
.
x
.
x
,
tmp
.
x
.
y
,
tmp
.
y
.
x
,
tmp
.
y
.
y
);
return
res
;
}
// fp8x8 -> float8
template
<
>
__inline__
__device__
Float8_
vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
)
{
Float4_
tmp1
,
tmp2
;
tmp1
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
);
tmp2
=
vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
);
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
z
=
tmp2
.
x
;
res
.
w
=
tmp2
.
y
;
return
res
;
}
// half -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
)
{
__half_raw
tmp
;
tmp
.
x
=
a
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
template
<
>
__inline__
__device__
uint16_t
vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
)
{
union
{
uint32_t
ui32
;
__half2_raw
h2r
;
}
tmp
;
tmp
.
ui32
=
a
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// bf16 -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
)
{
return
__hip_cvt_float_to_fp8
(
__bfloat162float
(
a
),
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// float -> fp8
template
<
>
__inline__
__device__
uint8_t
vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
)
{
return
__hip_cvt_float_to_fp8
(
a
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
}
// float2 -> half2
template
<
>
__inline__
__device__
uint32_t
vec_conversion
<
uint32_t
,
float2
>
(
const
float2
&
a
)
{
union
{
half2
float16
;
uint32_t
uint32
;
};
float16
=
__float22half2_rn
(
a
);
return
uint32
;
}
// Float4 -> half2x2
template
<
>
__inline__
__device__
uint2
vec_conversion
<
uint2
,
Float4_
>
(
const
Float4_
&
a
)
{
uint2
b
;
float2
val
;
val
.
x
=
a
.
x
.
x
;
val
.
y
=
a
.
x
.
y
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
val
.
x
=
a
.
y
.
x
;
val
.
y
=
a
.
y
.
y
;
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
val
);
return
b
;
}
// Float4 -> float4
template
<
>
__inline__
__device__
float4
vec_conversion
<
float4
,
Float4_
>
(
const
Float4_
&
a
)
{
float4
b
;
b
.
x
=
a
.
x
.
x
;
b
.
y
=
a
.
x
.
y
;
b
.
z
=
a
.
y
.
x
;
b
.
w
=
a
.
y
.
y
;
return
b
;
}
// Float8 -> half2x4
template
<
>
__inline__
__device__
uint4
vec_conversion
<
uint4
,
Float8_
>
(
const
Float8_
&
a
)
{
uint4
b
;
b
.
x
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
x
);
b
.
y
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
y
);
b
.
z
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
z
);
b
.
w
=
vec_conversion
<
uint32_t
,
float2
>
(
a
.
w
);
return
b
;
}
// float2 -> bfloat162
template
<
>
__inline__
__device__
__nv_bfloat162
vec_conversion
<
__nv_bfloat162
,
float2
>
(
const
float2
&
a
)
{
__nv_bfloat162
b
=
__float22bfloat162_rn
(
a
);
return
b
;
}
// Float4 -> bfloat162x2
template
<
>
__inline__
__device__
bf16_4_t
vec_conversion
<
bf16_4_t
,
Float4_
>
(
const
Float4_
&
a
)
{
bf16_4_t
b
;
b
.
x
=
__float22bfloat162_rn
(
a
.
x
);
b
.
y
=
__float22bfloat162_rn
(
a
.
y
);
return
b
;
}
// Float8 -> bfloat162x4
template
<
>
__inline__
__device__
bf16_8_t
vec_conversion
<
bf16_8_t
,
Float8_
>
(
const
Float8_
&
a
)
{
bf16_8_t
b
;
b
.
x
=
__float22bfloat162_rn
(
a
.
x
);
b
.
y
=
__float22bfloat162_rn
(
a
.
y
);
b
.
z
=
__float22bfloat162_rn
(
a
.
z
);
b
.
w
=
__float22bfloat162_rn
(
a
.
w
);
return
b
;
}
/* Scaled and vectorized conversions, for data exchange between high and low
precision domains
...
...
@@ -345,11 +380,10 @@ using __nv_bfloat16 = __hip_bfloat16;
template
<
>
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
// fp8_type f8;
// f8.__x = a;
// return __float2bfloat16(static_cast<float>(f8) * scale);
fp8_type
f8
;
f8
.
__x
=
a
;
return
__float2bfloat16
(
static_cast
<
float
>
(
f8
)
*
scale
);
// return __float2bfloat16(fp8_to_float(a) * scale);
}
// fp8x2 -> __nv_bfloat162
...
...
@@ -394,24 +428,24 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
return
fp8_to_float
(
a
)
*
scale
;
// fp8_type f8
;
// f8.__x = a
;
// return
static_cast<
float
>(f8
) * scale;
fp8_type
f8
;
f8
.
__x
=
a
;
return
static_cast
<
float
>
(
f8
)
*
scale
;
// return
fp8_to_
float
(a
) * scale;
}
// fp8x2 -> float2
template
<
>
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float2
f2r
;
f
2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
)
;
f
2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
)
;
return
f2r
;
//
[[maybe_unused]]
// f
p8x2_type f8x2
;
// f
8x2.__x = a
;
// return
static_cast<float2>(f8x2) * scale
;
// [[maybe_unused]]
f
p8x2_type
f8x2
;
f
8x2
.
__x
=
a
;
return
static_cast
<
float2
>
(
f8x2
)
*
scale
;
//
float2 f2r;
// f
2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale)
;
// f
2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale)
;
// return
f2r
;
}
// fp8x4 -> float4
...
...
@@ -451,34 +485,35 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
float
res
=
fp8_to_float
(
a
)
*
scale
;
re
turn
float_to_half
(
res
);
// __half_raw
res;
//
res.data = scaled_vec_conversion<float, uint8_t>(a,
scale
)
;
// return res
.x
;
__half_raw
res
;
re
s
.
data
=
scaled_vec_conversion
<
float
,
uint8_t
>
(
a
,
scale
);
return
res
.
x
;
//
float res = fp8_to_float(a) *
scale;
// return
float_to_half(
res
)
;
}
// fp8x2 -> half2
template
<
>
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
union
{
uint16_t
u16
[
2
];
uint32_t
u32
;
}
res
;
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
return
res
.
u32
;
// [[maybe_unused]] __half2_raw h2r =
// __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
union
{
__half2_raw
h2r
;
uint32_t
ui32
;
}
tmp
;
tmp
.
h2r
=
__hip_cvt_fp8x2_to_halfraw2
(
a
,
fp8_type
::
__default_interpret
);
tmp
.
h2r
.
x
.
data
*=
scale
;
tmp
.
h2r
.
y
.
data
*=
scale
;
return
tmp
.
ui32
;
// union {
// __half2_raw h2r;
// uint32_t ui32;
// } tmp;
// tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
// tmp.h2r.x.data *= scale;
// tmp.h2r.y.data *= scale;
// return tmp.ui32;
// uint16_t u16[2];
// uint32_t u32;
// } res;
// res.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)a, scale);
// res.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>((uint8_t)(a >> 8U), scale);
// return res.u32;
}
// fp8x4 -> half2x2
...
...
@@ -490,7 +525,8 @@ scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
uint32_t
u32
[
2
];
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
return
tmp
.
u32x2
;
}
...
...
@@ -511,40 +547,40 @@ __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8
(
res_f
)
;
// __half_raw tmp
;
// tmp.x = a;
// tmp.data /= scale
;
//
return __hip_cvt_
half
raw
_to_f
p8(tmp, fp8_type::__default_saturation,
//
fp8_type::__default_interpret
);
__half_raw
tmp
;
tmp
.
x
=
a
;
tmp
.
data
/=
scale
;
return
__hip_cvt_halfraw_to_fp8
(
tmp
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
)
;
//
float res_f =
half_to_f
loat(a) / scale;
//
return float_to_fp8(res_f
);
}
// halfx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
union
{
uint8_t
ui8
[
2
];
uint16_t
ui16
;
}
tmp
;
union
{
uint32_t
ui32
;
half2
h2r
;
}
tmp_a
;
tmp_a
.
ui32
=
a
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
return
tmp
.
ui16
;
__half2_raw
h2r
;
}
tmp
;
tmp
.
ui32
=
a
;
tmp
.
h2r
.
x
.
data
/=
scale
;
tmp
.
h2r
.
y
.
data
/=
scale
;
return
__hip_cvt_halfraw2_to_fp8x2
(
tmp
.
h2r
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
// union {
// uint
32
_t ui
32
;
//
__half2_raw h2r
;
// uint
8
_t ui
8[2]
;
//
uint16_t ui16
;
// } tmp;
// tmp.ui32 = a;
// tmp.h2r.x.data /= scale;
// tmp.h2r.y.data /= scale;
// return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
// fp8_type::__default_interpret);
// union {
// uint32_t ui32;
// half2 h2r;
// } tmp_a;
// tmp_a.ui32 = a;
// tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[0], scale);
// tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(tmp_a.h2r.data[1], scale);
// return tmp.ui16;
}
// half2x2 -> fp8x4
...
...
@@ -579,11 +615,11 @@ __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
)
)
/
scale
;
return
float_to_fp8
(
res_f
);
// return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
//
fp8_type::__default_saturation,
//
fp8_type::__default_interpret
);
return
__hip_cvt_float_to_fp8
(
__bfloat162
float
(
a
)
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
);
//
float res_f = (static_cast<float>(a)) / scale;
//
return float_to_fp8(res_f
);
}
// bf16x2 -> fp8x2
...
...
@@ -626,24 +662,24 @@ scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8
(
a
/
scale
);
// return __hip_cvt_float_to_fp8(a / scale,
fp8_type::__default_
saturation,
//
fp8_type::__default_interpret
);
return
__hip_cvt_
float_to_fp8
(
a
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_
interpret
);
//
return float_to_fp8(a / scale
);
}
// floatx2 -> fp8x2
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
)
{
union
{
uint8_t
ui8
[
2
]
;
uint16_t
ui16
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
)
;
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
)
;
return
tmp
.
ui
16
;
//
return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
//
fp8_type::__default_interpret)
;
return
__hip_cvt_float2_to_fp8x2
(
a
/
scale
,
fp8_type
::
__default_saturation
,
fp8_type
::
__default_interpret
)
;
// union {
// uint8_t ui8[2]
;
// uint16_t ui16
;
// } tmp
;
//
tmp.ui
8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale)
;
//
tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
//
return tmp.ui16
;
}
// floatx4 -> fp8x4
...
...
@@ -658,27 +694,27 @@ scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
return
tmp
.
ui32
;
}
//
#endif // ENABLE_FP8
#endif // ENABLE_FP8
//
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
//
__inline__ __device__ Tout convert(const Tin& x) {
//
#ifdef ENABLE_FP8
//
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
//
return vec_conversion<Tout, Tin>(x);
//
}
//
#endif
//
assert(false);
//
return {}; // Squash missing return statement warning
//
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
convert
(
const
Tin
&
x
)
{
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
vec_conversion
<
Tout
,
Tin
>
(
x
);
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
//
#ifdef ENABLE_FP8
//
if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
#ifdef ENABLE_FP8
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
//
}
//
#endif
//
assert(false);
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
...
...
vllm/utils/__init__.py
View file @
7a81bc31
...
...
@@ -132,8 +132,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"fp8"
:
torch
.
uint8
,
#
"fp8_e4m3": torch.uint8,
#
"fp8_e5m2": torch.uint8,
"fp8_e4m3"
:
torch
.
uint8
,
"fp8_e5m2"
:
torch
.
uint8
,
"int8"
:
torch
.
int8
,
"fp8_inc"
:
torch
.
float8_e4m3fn
,
"fp8_ds_mla"
:
torch
.
uint8
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment