Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69a8c8e9
Unverified
Commit
69a8c8e9
authored
Sep 25, 2025
by
Jonas M. Kübler
Committed by
GitHub
Sep 25, 2025
Browse files
[torch.compile] Make Query Quantization Fusable (#24914)
Signed-off-by:
Jonas Kuebler
<
kuebj@amazon.com
>
parent
6c340da4
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
8 deletions
+32
-8
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+8
-0
vllm/attention/layer.py
vllm/attention/layer.py
+22
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-7
No files found.
vllm/attention/backends/abstract.py
View file @
69a8c8e9
...
@@ -31,6 +31,14 @@ class AttentionBackend(ABC):
...
@@ -31,6 +31,14 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer
:
bool
=
False
accept_output_buffer
:
bool
=
False
# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input
:
bool
=
False
@
staticmethod
@
staticmethod
@
abstractmethod
@
abstractmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
...
vllm/attention/layer.py
View file @
69a8c8e9
...
@@ -22,7 +22,10 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...
@@ -22,7 +22,10 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.linear
import
UnquantizedLinearMethod
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
)
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.model_executor.models.vision
import
get_vit_attn_backend
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
GiB_bytes
,
direct_register_custom_op
from
vllm.utils
import
GiB_bytes
,
direct_register_custom_op
...
@@ -247,6 +250,13 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -247,6 +250,13 @@ class Attention(nn.Module, AttentionLayerBase):
"This may be caused by insufficient memory to allocate "
"This may be caused by insufficient memory to allocate "
"kv cache."
)
from
e
"kv cache."
)
from
e
# for attn backends supporting query quantization
self
.
query_quant
=
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
self
.
attn_backend
.
supports_quant_query_input
:
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -270,11 +280,22 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -270,11 +280,22 @@ class Attention(nn.Module, AttentionLayerBase):
attn_metadata
=
get_forward_context
().
attn_metadata
attn_metadata
=
get_forward_context
().
attn_metadata
if
attn_metadata
.
enable_kv_scales_calculation
:
if
attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
query
,
key
,
value
)
self
.
calc_kv_scales
(
query
,
key
,
value
)
output_dtype
=
query
.
dtype
if
self
.
query_quant
is
not
None
:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
query
,
_
=
self
.
query_quant
(
query
,
self
.
_q_scale
)
if
self
.
use_output
:
if
self
.
use_output
:
output_shape
=
(
output_shape
output_shape
=
(
output_shape
if
output_shape
is
not
None
else
query
.
shape
)
if
output_shape
is
not
None
else
query
.
shape
)
output
=
torch
.
zeros
(
output_shape
,
output
=
torch
.
zeros
(
output_shape
,
dtype
=
query
.
dtype
,
dtype
=
output_
dtype
,
device
=
query
.
device
)
device
=
query
.
device
)
hidden_size
=
output_shape
[
-
1
]
hidden_size
=
output_shape
[
-
1
]
# We skip reshaping query, key and value tensors for the MLA
# We skip reshaping query, key and value tensors for the MLA
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
69a8c8e9
...
@@ -7,7 +7,6 @@ from typing import Optional
...
@@ -7,7 +7,6 @@ from typing import Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
...
@@ -38,6 +37,7 @@ logger = init_logger(__name__)
...
@@ -38,6 +37,7 @@ logger = init_logger(__name__)
class
FlashAttentionBackend
(
AttentionBackend
):
class
FlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
accept_output_buffer
:
bool
=
True
supports_quant_query_input
:
bool
=
True
@
classmethod
@
classmethod
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
def
get_supported_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
...
@@ -506,16 +506,11 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -506,16 +506,11 @@ class FlashAttentionImpl(AttentionImpl):
)
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
# queries are quantized in the attention layer
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
self
.
kv_cache_dtype
)
self
.
kv_cache_dtype
)
key_cache
=
key_cache
.
view
(
dtype
)
key_cache
=
key_cache
.
view
(
dtype
)
value_cache
=
value_cache
.
view
(
dtype
)
value_cache
=
value_cache
.
view
(
dtype
)
num_tokens
,
num_heads
,
head_size
=
query
.
shape
query
,
_
=
ops
.
scaled_fp8_quant
(
query
.
reshape
(
(
num_tokens
,
num_heads
*
head_size
)).
contiguous
(),
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
if
not
attn_metadata
.
use_cascade
:
if
not
attn_metadata
.
use_cascade
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
cu_seqlens_q
=
attn_metadata
.
query_start_loc
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment