Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b3062dab
Commit
b3062dab
authored
Jan 23, 2026
by
zhuwenwen
Browse files
support fa kvcache fp8, add VLLM_USE_QUERY_QUANT to not use q quant(todo)
parent
4e51cae7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
164 additions
and
86 deletions
+164
-86
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+111
-65
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+3
-0
vllm/envs.py
vllm/envs.py
+6
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+44
-21
No files found.
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
b3062dab
...
@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
...
@@ -27,7 +27,7 @@ static inline __device__ float fp8_to_float(uint8_t input) {
}
}
// float -> fp8
// float -> fp8
static
inline
__device__
uint8_t
float_to_fp8
(
float
f
)
{
static
inline
__device__
uint8_t
float_to_fp8
_e4m3
(
float
f
)
{
constexpr
uint32_t
fp8_max
=
UINT32_C
(
1087
)
<<
20
;
constexpr
uint32_t
fp8_max
=
UINT32_C
(
1087
)
<<
20
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
141
)
<<
23
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
141
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
...
@@ -53,10 +53,35 @@ static inline __device__ uint8_t float_to_fp8(float f) {
...
@@ -53,10 +53,35 @@ static inline __device__ uint8_t float_to_fp8(float f) {
return
result
;
return
result
;
}
}
static
inline
__device__
uint8_t
float_to_fp8_e5m2
(
float
f
)
{
constexpr
uint32_t
fp32_inf
=
UINT32_C
(
255
)
<<
23
;
constexpr
uint32_t
fp8_max
=
UINT32_C
(
143
)
<<
23
;
constexpr
uint32_t
denorm_mask
=
UINT32_C
(
134
)
<<
23
;
uint32_t
f_bits
=
c10
::
detail
::
fp32_to_bits
(
f
);
uint8_t
result
=
0u
;
const
uint32_t
sign
=
f_bits
&
UINT32_C
(
0x80000000
);
f_bits
^=
sign
;
if
(
f_bits
>=
fp8_max
)
{
result
=
f_bits
>
fp32_inf
?
UINT8_C
(
0x7F
)
:
UINT8_C
(
0x7C
);
}
else
{
if
(
f_bits
<
(
UINT32_C
(
113
)
<<
23
))
{
f_bits
=
c10
::
detail
::
fp32_to_bits
(
c10
::
detail
::
fp32_from_bits
(
f_bits
)
+
c10
::
detail
::
fp32_from_bits
(
denorm_mask
));
result
=
static_cast
<
uint8_t
>
(
f_bits
-
denorm_mask
);
}
else
{
uint32_t
mant_odd
=
(
f_bits
>>
21
)
&
1
;
f_bits
+=
((
uint32_t
)(
15
-
127
)
<<
23
)
+
0xFFFFF
;
f_bits
+=
mant_odd
;
result
=
static_cast
<
uint8_t
>
(
f_bits
>>
21
);
}
}
result
|=
static_cast
<
uint8_t
>
(
sign
>>
24
);
return
result
;
}
template
<
typename
Tout
,
typename
Tin
>
template
<
typename
Tout
,
typename
Tin
>
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
__inline__
__device__
Tout
scaled_vec_conversion
(
const
Tin
&
x
,
const
float
scale
)
{
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
return
x
;
return
x
;
}
}
...
@@ -65,8 +90,10 @@ using __nv_bfloat16 = __hip_bfloat16;
...
@@ -65,8 +90,10 @@ using __nv_bfloat16 = __hip_bfloat16;
// fp8 -> __nv_bfloat16
// fp8 -> __nv_bfloat16
template
<
>
template
<
>
__inline__
__device__
__nv_bfloat16
__inline__
__device__
__nv_bfloat16
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
return
__float2bfloat16
(
fp8_to_float
(
a
)
*
scale
);
}
}
...
@@ -74,32 +101,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
...
@@ -74,32 +101,32 @@ scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
template
<
>
template
<
>
__inline__
__device__
__nv_bfloat162
__inline__
__device__
__nv_bfloat162
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
__nv_bfloat162
res
;
__nv_bfloat162
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
y
=
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
scaled_vec_conversion
<
__nv_bfloat16
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
;
return
res
;
}
}
// fp8x4 -> bf16_4_t
// fp8x4 -> bf16_4_t
template
<
>
template
<
>
__inline__
__device__
bf16_4_t
__inline__
__device__
bf16_4_t
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
res
;
bf16_4_t
res
;
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
res
.
y
=
scaled_vec_conversion
<
__nv_bfloat162
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
scale
,
kv_type
);
return
res
;
return
res
;
}
}
// fp8x8 -> bf16_8_t
// fp8x8 -> bf16_8_t
template
<
>
template
<
>
__inline__
__device__
bf16_8_t
__inline__
__device__
bf16_8_t
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
bf16_8_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
bf16_4_t
tmp1
,
tmp2
;
bf16_4_t
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp1
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp2
=
scaled_vec_conversion
<
bf16_4_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
bf16_8_t
res
;
bf16_8_t
res
;
res
.
x
=
tmp1
.
x
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
y
=
tmp1
.
y
;
...
@@ -111,24 +138,27 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
...
@@ -111,24 +138,27 @@ scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
// fp8 -> float
// fp8 -> float
template
<
>
template
<
>
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
__inline__
__device__
float
scaled_vec_conversion
<
float
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
return
fp8_to_float
(
a
)
*
scale
;
return
fp8_to_float
(
a
)
*
scale
;
}
}
// fp8x2 -> float2
// fp8x2 -> float2
template
<
>
template
<
>
__inline__
__device__
float2
__inline__
__device__
float2
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
float2
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float2
f2r
;
float2
f2r
;
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
f2r
.
x
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
f2r
.
y
=
scaled_vec_conversion
<
float
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
f2r
;
return
f2r
;
}
}
// fp8x4 -> float4
// fp8x4 -> float4
template
<
>
template
<
>
__inline__
__device__
Float4_
__inline__
__device__
Float4_
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
)
{
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
const
uint32_t
&
a
,
const
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
res
;
Float4_
res
;
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
x
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
res
.
y
=
scaled_vec_conversion
<
float2
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
...
@@ -138,18 +168,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
...
@@ -138,18 +168,18 @@ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
// fp8x4 -> float4
// fp8x4 -> float4
template
<
>
template
<
>
__inline__
__device__
float4
__inline__
__device__
float4
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
float4
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
kv_type
)
{
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
);
Float4_
res
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
,
scale
,
kv_type
);
return
{
res
.
x
.
x
,
res
.
x
.
y
,
res
.
y
.
x
,
res
.
y
.
y
};
return
{
res
.
x
.
x
,
res
.
x
.
y
,
res
.
y
.
x
,
res
.
y
.
y
};
}
}
// fp8x8 -> float8
// fp8x8 -> float8
template
<
>
template
<
>
__inline__
__device__
Float8_
__inline__
__device__
Float8_
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
Float8_
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
Float4_
tmp1
,
tmp2
;
Float4_
tmp1
,
tmp2
;
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
);
tmp1
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
);
tmp2
=
scaled_vec_conversion
<
Float4_
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
Float8_
res
;
Float8_
res
;
res
.
x
=
tmp1
.
x
;
res
.
x
=
tmp1
.
x
;
res
.
y
=
tmp1
.
y
;
res
.
y
=
tmp1
.
y
;
...
@@ -161,7 +191,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
...
@@ -161,7 +191,10 @@ scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
// fp8 -> half
// fp8 -> half
template
<
>
template
<
>
__inline__
__device__
uint16_t
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
(
const
uint8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E5M2
)
{
assert
(
false
);
}
float
res
=
fp8_to_float
(
a
)
*
scale
;
float
res
=
fp8_to_float
(
a
)
*
scale
;
return
float_to_half
(
res
);
return
float_to_half
(
res
);
}
}
...
@@ -169,54 +202,58 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
...
@@ -169,54 +202,58 @@ scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
// fp8x2 -> half2
// fp8x2 -> half2
template
<
>
template
<
>
__inline__
__device__
uint32_t
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint16_t
u16
[
2
];
uint16_t
u16
[
2
];
uint32_t
u32
;
uint32_t
u32
;
}
res
;
}
res
;
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
);
res
.
u16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)
a
,
scale
,
kv_type
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
);
res
.
u16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint8_t
>
((
uint8_t
)(
a
>>
8U
),
scale
,
kv_type
);
return
res
.
u32
;
return
res
.
u32
;
}
}
// fp8x4 -> half2x2
// fp8x4 -> half2x2
template
<
>
template
<
>
__inline__
__device__
uint2
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint2
u32x2
;
uint2
u32x2
;
uint32_t
u32
[
2
];
uint32_t
u32
[
2
];
}
tmp
;
}
tmp
;
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
);
tmp
.
u32
[
0
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)
a
,
scale
,
kv_type
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
);
tmp
.
u32
[
1
]
=
scaled_vec_conversion
<
uint32_t
,
uint16_t
>
((
uint16_t
)(
a
>>
16U
),
scale
,
kv_type
);
return
tmp
.
u32x2
;
return
tmp
.
u32x2
;
}
}
// fp8x8 -> half2x4
// fp8x8 -> half2x4
template
<
>
template
<
>
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
__inline__
__device__
uint4
scaled_vec_conversion
<
uint4
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint4
u64x2
;
uint4
u64x2
;
uint2
u64
[
2
];
uint2
u64
[
2
];
}
tmp
;
}
tmp
;
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
u64
[
0
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
u64
[
1
]
=
scaled_vec_conversion
<
uint2
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
u64x2
;
return
tmp
.
u64x2
;
}
}
// half -> fp8
// half -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8
(
res_f
);
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
}
}
// halfx2 -> fp8x2
// halfx2 -> fp8x2
template
<
>
template
<
>
__inline__
__device__
uint16_t
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
const
uint32_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint8_t
ui8
[
2
];
uint8_t
ui8
[
2
];
uint16_t
ui16
;
uint16_t
ui16
;
...
@@ -226,113 +263,122 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
...
@@ -226,113 +263,122 @@ scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
half2
h2r
;
half2
h2r
;
}
tmp_a
;
}
tmp_a
;
tmp_a
.
ui32
=
a
;
tmp_a
.
ui32
=
a
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
0
],
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
uint16_t
>
(
tmp_a
.
h2r
.
data
[
1
],
scale
,
kv_type
);
return
tmp
.
ui16
;
return
tmp
.
ui16
;
}
}
// half2x2 -> fp8x4
// half2x2 -> fp8x4
template
<
>
template
<
>
__inline__
__device__
uint32_t
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
const
uint2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint16_t
ui16
[
2
];
uint16_t
ui16
[
2
];
uint32_t
ui32
;
uint32_t
ui32
;
}
tmp
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
uint32_t
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
return
tmp
.
ui32
;
}
}
// half2x4 -> fp8x8
// half2x4 -> fp8x8
template
<
>
template
<
>
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint4
>
(
const
uint4
&
a
,
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
uint4
>
(
const
uint4
&
a
,
float
scale
)
{
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint2
ui2
[
2
];
uint2
ui2
[
2
];
uint4
ui4
;
uint4
ui4
;
}
tmp
;
}
tmp
;
tmp
.
ui4
=
a
;
tmp
.
ui4
=
a
;
uint2
res
;
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
0
],
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
uint2
>
(
tmp
.
ui2
[
1
],
scale
,
kv_type
);
return
res
;
return
res
;
}
}
// bf16 -> fp8
// bf16 -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
const
__nv_bfloat16
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8
(
res_f
);
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
res_f
);
}
else
{
return
float_to_fp8_e5m2
(
res_f
);
}
}
}
// bf16x2 -> fp8x2
// bf16x2 -> fp8x2
template
<
>
template
<
>
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
const
__nv_bfloat162
&
a
,
float
scale
)
{
const
__nv_bfloat162
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint8_t
ui8
[
2
];
uint8_t
ui8
[
2
];
uint16_t
ui16
;
uint16_t
ui16
;
}
tmp
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
__nv_bfloat16
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
return
tmp
.
ui16
;
}
}
// bf16x4 -> fp8x4
// bf16x4 -> fp8x4
template
<
>
template
<
>
__inline__
__device__
uint32_t
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
(
const
bf16_4_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint16_t
ui16
[
2
];
uint16_t
ui16
[
2
];
uint32_t
ui32
;
uint32_t
ui32
;
}
tmp
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
__nv_bfloat162
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui32
;
return
tmp
.
ui32
;
}
}
// bf16x8 -> fp8x8
// bf16x8 -> fp8x8
template
<
>
template
<
>
__inline__
__device__
uint2
__inline__
__device__
uint2
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint2
,
bf16_8_t
>
(
const
bf16_8_t
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
uint2
res
;
uint2
res
;
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
);
res
.
x
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
);
res
.
y
=
scaled_vec_conversion
<
uint32_t
,
bf16_4_t
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
res
;
return
res
;
}
}
// float -> fp8
// float -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint8_t
,
float
>
(
const
float
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
return
float_to_fp8
(
a
/
scale
);
if
(
kv_type
==
vllm
::
Fp8KVCacheDataType
::
kFp8E4M3
)
{
return
float_to_fp8_e4m3
(
a
/
scale
);
}
else
{
return
float_to_fp8_e5m2
(
a
/
scale
);
}
}
}
// floatx2 -> fp8x2
// floatx2 -> fp8x2
template
<
>
template
<
>
__inline__
__device__
uint16_t
__inline__
__device__
uint16_t
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint16_t
,
float2
>
(
const
float2
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint8_t
ui8
[
2
];
uint8_t
ui8
[
2
];
uint16_t
ui16
;
uint16_t
ui16
;
}
tmp
;
}
tmp
;
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
);
tmp
.
ui8
[
0
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
x
,
scale
,
kv_type
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
);
tmp
.
ui8
[
1
]
=
scaled_vec_conversion
<
uint8_t
,
float
>
(
a
.
y
,
scale
,
kv_type
);
return
tmp
.
ui16
;
return
tmp
.
ui16
;
}
}
// floatx4 -> fp8x4
// floatx4 -> fp8x4
template
<
>
template
<
>
__inline__
__device__
uint32_t
__inline__
__device__
uint32_t
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
)
{
scaled_vec_conversion
<
uint32_t
,
float4
>
(
const
float4
&
a
,
float
scale
,
Fp8KVCacheDataType
kv_type
)
{
union
{
union
{
uint16_t
ui16
[
2
];
uint16_t
ui16
[
2
];
uint32_t
ui32
;
uint32_t
ui32
;
}
tmp
;
}
tmp
;
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
);
tmp
.
ui16
[
0
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
x
,
a
.
y
},
scale
,
kv_type
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
);
tmp
.
ui16
[
1
]
=
scaled_vec_conversion
<
uint16_t
,
float2
>
({
a
.
z
,
a
.
w
},
scale
,
kv_type
);
return
tmp
.
ui32
;
return
tmp
.
ui32
;
}
}
...
@@ -433,8 +479,8 @@ scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
...
@@ -433,8 +479,8 @@ scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
)
{
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E4M3
||
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
)
{
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
,
kv_dt
);
}
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
return
scaled_vec_conversion_to_e5m2
<
Tin
>
(
x
,
scale
);
return
scaled_vec_conversion_to_e5m2
<
Tin
>
(
x
,
scale
);
...
...
vllm/attention/utils/fa_utils.py
View file @
b3062dab
...
@@ -5,6 +5,7 @@ from typing import Optional
...
@@ -5,6 +5,7 @@ from typing import Optional
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
import
torch
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -68,6 +69,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
...
@@ -68,6 +69,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
def
flash_attn_supports_fp8
()
->
bool
:
def
flash_attn_supports_fp8
()
->
bool
:
if
current_platform
.
is_rocm
():
return
True
return
get_flash_attn_version
()
==
3
and
\
return
get_flash_attn_version
()
==
3
and
\
current_platform
.
get_device_capability
().
major
==
9
current_platform
.
get_device_capability
().
major
==
9
...
...
vllm/envs.py
View file @
b3062dab
...
@@ -149,6 +149,7 @@ if TYPE_CHECKING:
...
@@ -149,6 +149,7 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_TRITON_OPT_MLA
:
bool
=
False
VLLM_USE_FLASH_ATTN_FP8
:
bool
=
False
VLLM_USE_FLASH_ATTN_FP8
:
bool
=
False
VLLM_USE_QUERY_QUANT
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_FLASH_MLA_FP8
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
...
@@ -1071,6 +1072,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1071,6 +1072,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FLASH_ATTN_FP8"
:
"VLLM_USE_FLASH_ATTN_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_ATTN_FP8"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_ATTN_FP8"
,
"1"
))),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_QUERY_QUANT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, vLLM will use FLASH MLA attention optimizations.
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA"
:
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
b3062dab
...
@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -136,6 +136,17 @@ class FlashAttentionBackend(AttentionBackend):
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
raise
ValueError
(
f
"Unknown cache layout format
{
cache_layout
}
."
)
return
key_stride_order
,
value_stride_order
return
key_stride_order
,
value_stride_order
@
staticmethod
def
get_fp8_dtype_for_flashattn
(
kv_cache_dtype
:
str
)
->
torch
.
dtype
:
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
return
torch
.
float8_e4m3fn
else
:
raise
ValueError
(
f
"Unsupported FP8 dtype:
{
kv_cache_dtype
}
"
)
elif
kv_cache_dtype
in
(
"fp8_e5m2"
):
return
torch
.
float8_e5m2
else
:
raise
ValueError
(
f
"Unrecognized FP8 dtype:
{
kv_cache_dtype
}
"
)
@
dataclass
@
dataclass
class
FlashAttentionMetadata
:
class
FlashAttentionMetadata
:
...
@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -589,14 +600,19 @@ class FlashAttentionImpl(AttentionImpl):
)
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
torch
.
float8_e4m3fn
)
# key_cache = key_cache.view(torch.float8_e4m3fn)
value_cache
=
value_cache
.
view
(
torch
.
float8_e4m3fn
)
# value_cache = value_cache.view(torch.float8_e4m3fn)
num_tokens
,
num_heads
,
head_size
=
query
.
shape
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
query
,
_
=
ops
.
scaled_fp8_quant
(
self
.
kv_cache_dtype
)
query
.
reshape
(
key_cache
=
key_cache
.
view
(
dtype
)
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
value_cache
=
value_cache
.
view
(
dtype
)
layer
.
_q_scale
)
if
envs
.
VLLM_USE_QUERY_QUANT
:
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
num_tokens
,
num_heads
,
head_size
=
query
.
shape
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn
=
\
use_local_attn
=
\
...
@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -620,9 +636,10 @@ class FlashAttentionImpl(AttentionImpl):
block_table
=
attn_metadata
.
block_table
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
#
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
if
not
current_platform
.
is_rocm
():
if
not
current_platform
.
is_rocm
():
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
flash_attn_varlen_func
(
flash_attn_varlen_func
(
q
=
query
[:
num_actual_tokens
],
q
=
query
[:
num_actual_tokens
],
k
=
key_cache
,
k
=
key_cache
,
...
@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -672,6 +689,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape),
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
# num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache
=
True
,
is_prefix_cache
=
True
,
)
)
...
@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -729,6 +749,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
# v_descale=layer._v_scale,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
,
v_descale
=
layer
.
_v_scale
,
)
)
return
output
return
output
...
@@ -879,12 +902,12 @@ def cascade_attention(
...
@@ -879,12 +902,12 @@ def cascade_attention(
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
prefix_scheduler_metadata
,
scheduler_metadata
=
prefix_scheduler_metadata
,
# fa_version=fa_version,
# fa_version=fa_version,
#
q_descale=q_descale.expand(descale_shape)
q_descale
=
q_descale
.
expand
(
descale_shape
)
#
if q_descale is not None else None,
if
q_descale
is
not
None
else
None
,
#
k_descale=k_descale.expand(descale_shape)
k_descale
=
k_descale
.
expand
(
descale_shape
)
#
if k_descale is not None else None,
if
k_descale
is
not
None
else
None
,
#
v_descale=v_descale.expand(descale_shape)
v_descale
=
v_descale
.
expand
(
descale_shape
)
#
if v_descale is not None else None,
if
v_descale
is
not
None
else
None
,
is_prefix_cache
=
True
,
is_prefix_cache
=
True
,
)
)
...
@@ -932,12 +955,12 @@ def cascade_attention(
...
@@ -932,12 +955,12 @@ def cascade_attention(
return_softmax_lse
=
True
,
return_softmax_lse
=
True
,
scheduler_metadata
=
suffix_scheduler_metadata
,
scheduler_metadata
=
suffix_scheduler_metadata
,
# fa_version=fa_version,
# fa_version=fa_version,
#
q_descale=q_descale.expand(descale_shape)
q_descale
=
q_descale
.
expand
(
descale_shape
)
#
if q_descale is not None else None,
if
q_descale
is
not
None
else
None
,
#
k_descale=k_descale.expand(descale_shape)
k_descale
=
k_descale
.
expand
(
descale_shape
)
#
if k_descale is not None else None,
if
k_descale
is
not
None
else
None
,
#
v_descale=v_descale.expand(descale_shape)
v_descale
=
v_descale
.
expand
(
descale_shape
)
#
if v_descale is not None else None,
if
v_descale
is
not
None
else
None
,
is_prefix_cache
=
True
,
is_prefix_cache
=
True
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment