Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b2d14ba3
Commit
b2d14ba3
authored
Sep 28, 2025
by
yangql
Browse files
修复kvcache-fp8—e5m2的不能开cp的bug
parent
bb13d854
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
111 additions
and
26 deletions
+111
-26
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+46
-9
vllm/_custom_ops.py
vllm/_custom_ops.py
+7
-4
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+4
-3
vllm/attention/layer.py
vllm/attention/layer.py
+45
-3
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+3
-2
vllm/utils/__init__.py
vllm/utils/__init__.py
+2
-2
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+4
-3
No files found.
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
b2d14ba3
...
@@ -365,21 +365,21 @@ inline __device__ uint8_t float_to_fp8e5m2(float f) {
...
@@ -365,21 +365,21 @@ inline __device__ uint8_t float_to_fp8e5m2(float f) {
// fp8
// fp8
template
<
typename
Tin
>
template
<
typename
Tin
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion_e5m2
(
const
Tin
&
a
,
float
scale
)
{
scaled_vec_conversion_
to_
e5m2
(
const
Tin
&
a
,
float
scale
)
{
return
0
;
return
0
;
}
}
// float -> fp8
// float -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion_e5m2
<
float
>
(
const
float
&
a
,
float
scale
)
{
scaled_vec_conversion_
to_
e5m2
<
float
>
(
const
float
&
a
,
float
scale
)
{
return
float_to_fp8e5m2
(
a
/
scale
);
return
float_to_fp8e5m2
(
a
/
scale
);
}
}
// half -> fp8
// half -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion_e5m2
<
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
scaled_vec_conversion_
to_
e5m2
<
uint16_t
>
(
const
uint16_t
&
a
,
float
scale
)
{
float
res_f
=
half_to_float
(
a
)
/
scale
;
float
res_f
=
half_to_float
(
a
)
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
return
float_to_fp8e5m2
(
res_f
);
}
}
...
@@ -387,11 +387,49 @@ scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
...
@@ -387,11 +387,49 @@ scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
// bf16 -> fp8
// bf16 -> fp8
template
<
>
template
<
>
__inline__
__device__
uint8_t
__inline__
__device__
uint8_t
scaled_vec_conversion_e5m2
<
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
scaled_vec_conversion_
to_
e5m2
<
__nv_bfloat16
>
(
const
__nv_bfloat16
&
a
,
float
scale
)
{
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
float
res_f
=
(
static_cast
<
float
>
(
a
))
/
scale
;
return
float_to_fp8e5m2
(
res_f
);
return
float_to_fp8e5m2
(
res_f
);
}
}
inline
__device__
float
fp8e5m2_to_fp32
(
const
uint8_t
&
input
)
{
union
uf16
{
uint16_t
as_bits
;
_Float16
as_value
;
}
;
uf16
u16
;
u16
.
as_bits
=
(
uint16_t
)
input
<<
8
;
return
(
float
)
u16
.
as_value
;
}
// Primary template for fp8 (E5M2) -> Tout conversion.
//
// This generic version ignores both arguments and returns 0 converted to
// Tout; the real conversions are provided by the explicit specializations.
// NOTE(review): an unsupported Tout therefore yields zeros silently rather
// than a compile error — confirm this fallback is intended.
template <typename Tout>
__inline__ __device__ Tout scaled_vec_conversion_from_e5m2(
    const uint8_t& /*a*/, float /*scale*/) {
  return 0;
}
// fp8 (E5M2) -> float
// Decodes the byte with fp8e5m2_to_fp32 and multiplies the result by `scale`.
template <>
__inline__ __device__ float scaled_vec_conversion_from_e5m2<float>(
    const uint8_t& a, float scale) {
  const float decoded = fp8e5m2_to_fp32(a);
  return decoded * scale;
}
// fp8 (E5M2) -> half
// Decodes the byte to float, applies `scale`, then converts the product with
// float_to_half, which supplies the uint16_t return value.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion_from_e5m2<uint16_t>(
    const uint8_t& a, float scale) {
  const float scaled = fp8e5m2_to_fp32(a) * scale;
  return float_to_half(scaled);
}
// fp8 (E5M2) -> bf16
// Decodes the byte to float, applies `scale`, then converts the product to
// __nv_bfloat16 via __float2bfloat16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
  const float scaled = fp8e5m2_to_fp32(a) * scale;
  return __float2bfloat16(scaled);
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
__inline__
__device__
Tout
scaled_convert
(
const
Tin
&
x
,
const
float
scale
)
{
...
@@ -399,12 +437,11 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
...
@@ -399,12 +437,11 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
return
scaled_vec_conversion
<
Tout
,
Tin
>
(
x
,
scale
);
}
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tout
)
==
1
){
return
scaled_vec_conversion_e5m2
<
Tin
>
(
x
,
scale
);
return
scaled_vec_conversion_to_e5m2
<
Tin
>
(
x
,
scale
);
}
else
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kFp8E5M2
&&
sizeof
(
Tin
)
==
1
){
return
scaled_vec_conversion_from_e5m2
<
Tout
>
(
x
,
scale
);
}
}
// else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 &&
// (std::is_same<Tin, uint16_t>::value||std::is_same<Tin, __nv_bfloat16>::value)){
// return scaled_vec_conversion_e5m2<Tin>(x, scale);
// }
return
{};
// Squash missing return statement warning
return
{};
// Squash missing return statement warning
}
}
...
...
vllm/_custom_ops.py
View file @
b2d14ba3
...
@@ -2166,12 +2166,15 @@ def gather_cache(src_cache: torch.Tensor,
...
@@ -2166,12 +2166,15 @@ def gather_cache(src_cache: torch.Tensor,
kv_dtype
=
"auto"
,
kv_dtype
=
"auto"
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
)
->
None
:
)
->
None
:
#支持"kv cache fp8"
#支持"kv cache fp8"
临时方案,带dtype的gather_cache在vllm0.10后会实现。
if
kv_dtype
==
"fp8"
:
if
kv_dtype
==
"fp8"
or
kv_dtype
==
"fp8_e5m2"
or
kv_dtype
==
"fp8_e4m3"
:
dst_fp8
=
torch
.
zeros
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
dst_fp8
=
torch
.
empty
(
dst
.
shape
,
dtype
=
torch
.
uint8
,
device
=
dst
.
device
)
convert_fp8
(
dst_fp8
,
dst
,
scale
,
kv_dtype
)
#
convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst_fp8
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
cu_seq_lens
,
batch_size
,
seq_starts
)
#dst_fp8->bf16
convert_fp8
(
dst
,
dst_fp8
,
scale
,
kv_dtype
)
else
:
else
:
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
torch
.
ops
.
_C_cache_ops
.
gather_cache
(
src_cache
,
dst
,
block_table
,
cu_seq_lens
,
batch_size
,
seq_starts
)
cu_seq_lens
,
batch_size
,
seq_starts
)
...
...
vllm/attention/backends/flashmla.py
View file @
b2d14ba3
...
@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -211,9 +211,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
self
.
kv_cache_dtype
!=
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
raise
NotImplementedError
(
return
"FlashMLA with other KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
...
vllm/attention/layer.py
View file @
b2d14ba3
...
@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
...
@@ -24,6 +24,11 @@ from vllm.platforms import _Backend, current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
USE_XFORMERS_OPS = None

# Tags handed to the custom-op registration further down in this module.
# If `torch._C.Tag.cudagraph_unsafe` cannot be resolved (AttributeError),
# fall back to an empty tuple so registration proceeds without the tag.
try:
    tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError:
    tag_cudagraph_unsafe = ()  # type: ignore[assignment]
class
Attention
(
nn
.
Module
):
class
Attention
(
nn
.
Module
):
"""Attention layer.
"""Attention layer.
...
@@ -204,10 +209,12 @@ class Attention(nn.Module):
...
@@ -204,10 +209,12 @@ class Attention(nn.Module):
`vllm.forward_context.get_forward_context().attn_metadata`.
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
"""
if
self
.
calculate_kv_scales
:
if
self
.
calculate_kv_scales
:
attn_metadata
=
get_forward_context
().
attn_metadata
#
attn_metadata = get_forward_context().attn_metadata
if
(
attn_metadata
is
not
None
and
getattr
(
attn_metadata
,
"enable_kv_scales_calculation"
,
False
)):
# #
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
# if key is not None and value is not None:
self
.
calc_kv_scales
(
query
,
key
,
value
)
# self.calc_kv_scales(query, key, value)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
if
self
.
use_output
:
if
self
.
use_output
:
output_shape
=
(
output_shape
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
if
output_shape
is
not
None
else
query
.
shape
)
...
@@ -395,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
...
@@ -395,7 +402,42 @@ def maybe_save_kv_layer_to_connector(
assert
isinstance
(
attn_metadata
,
dict
)
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
attn_metadata
[
layer_name
])
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run KV-scale calculation for the attention layer named `layer_name`.

    The layer object is looked up in the current forward context's
    `no_compile_layers` mapping and its `calc_kv_scales(query, key, value)`
    method is invoked.

    Args:
        query: Query tensor forwarded unchanged to `calc_kv_scales`.
        key: Key tensor forwarded unchanged to `calc_kv_scales`.
        value: Value tensor forwarded unchanged to `calc_kv_scales`.
        layer_name: Key used to find the layer (and, when the attention
            metadata is a dict, the per-layer metadata entry).

    Raises:
        KeyError: If `layer_name` is missing from a dict-valued
            `attn_metadata` or from `no_compile_layers`.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]

    # NOTE(review): the guard below is disabled, so `attn_metadata` no longer
    # influences control flow and the scales are recalculated on every call.
    # Confirm this is intentional before re-enabling or removing the lookup.
    # if attn_metadata is None or not getattr(
    #         attn_metadata, 'enable_kv_scales_calculation', False):
    #     return

    layer = forward_context.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """No-op counterpart of `maybe_calc_kv_scales`.

    Accepts the same arguments, does nothing with them and returns None. It is
    passed as `fake_impl` when the op is registered.
    """
    return
# Register `maybe_calc_kv_scales` as a custom op with a no-op fake
# implementation; `Attention.forward` calls it as
# `torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)`.
# NOTE(review): `mutates_args` is empty although the real implementation calls
# `calc_kv_scales` on the layer object — confirm no tensor argument is mutated.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=[],
    fake_impl=maybe_calc_kv_scales_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=tag_cudagraph_unsafe,
)
def
unified_attention
(
def
unified_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
...
vllm/attention/ops/flashmla.py
View file @
b2d14ba3
...
@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
...
@@ -99,7 +99,8 @@ def flash_mla_with_kvcache(
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
kv_dtype
=
"fp8_e4m3"
if
kv_cache_dtype
==
"fp8"
else
kv_cache_dtype
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_quantization_mla
(
q
,
q
,
k_cache
,
k_cache
,
...
@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
...
@@ -112,7 +113,7 @@ def flash_mla_with_kvcache(
tile_scheduler_metadata
,
tile_scheduler_metadata
,
num_splits
,
num_splits
,
k_scale
,
k_scale
,
"fp8_e4m3"
,
kv_dtype
,
)
)
return
out
,
softmax_lse
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
...
...
vllm/utils/__init__.py
View file @
b2d14ba3
...
@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
...
@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16"
:
torch
.
bfloat16
,
"bfloat16"
:
torch
.
bfloat16
,
"float"
:
torch
.
float
,
"float"
:
torch
.
float
,
"fp8"
:
torch
.
uint8
,
"fp8"
:
torch
.
uint8
,
#
"fp8_e4m3": torch.uint8,
"fp8_e4m3"
:
torch
.
uint8
,
#
"fp8_e5m2": torch.uint8,
"fp8_e5m2"
:
torch
.
uint8
,
"int8"
:
torch
.
int8
,
"int8"
:
torch
.
int8
,
}
}
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
b2d14ba3
...
@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
self
.
kv_cache_dtype
!=
"fp8"
:
if
kv_cache_dtype
==
"fp8"
or
kv_cache_dtype
==
"fp8_e4m3"
or
kv_cache_dtype
==
"fp8_e5m2"
:
raise
NotImplementedError
(
return
"FlashMLA with other KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment