Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8bd8db7
Commit
c8bd8db7
authored
Jan 15, 2026
by
zhuwenwen
Browse files
support fa kvcache fp8
todo: add VLLM_USE_QUERY_QUANT to not use q quant
parent
2a75c6bc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
5 deletions
+17
-5
vllm/attention/layer.py
vllm/attention/layer.py
+6
-4
vllm/attention/utils/fa_utils.py
vllm/attention/utils/fa_utils.py
+4
-1
vllm/envs.py
vllm/envs.py
+7
-0
No files found.
vllm/attention/layer.py
View file @
c8bd8db7
...
@@ -255,10 +255,12 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -255,10 +255,12 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization
# for attn backends supporting query quantization
self
.
query_quant
=
None
self
.
query_quant
=
None
if
self
.
kv_cache_dtype
.
startswith
(
# @TODO
"fp8"
)
and
self
.
attn_backend
.
supports_quant_query_input
:
if
envs
.
VLLM_USE_QUERY_QUANT
:
self
.
query_quant
=
QuantFP8
(
static
=
True
,
if
self
.
kv_cache_dtype
.
startswith
(
group_shape
=
GroupShape
.
PER_TENSOR
)
"fp8"
)
and
self
.
attn_backend
.
supports_quant_query_input
:
self
.
query_quant
=
QuantFP8
(
static
=
True
,
group_shape
=
GroupShape
.
PER_TENSOR
)
def
forward
(
def
forward
(
self
,
self
,
...
...
vllm/attention/utils/fa_utils.py
View file @
c8bd8db7
...
@@ -5,6 +5,7 @@ from typing import Optional
...
@@ -5,6 +5,7 @@ from typing import Optional
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
import
torch
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -61,13 +62,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
...
@@ -61,13 +62,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
logger
.
error
(
"Cannot use FA version %d is not supported due to %s"
,
logger
.
error
(
"Cannot use FA version %d is not supported due to %s"
,
fa_version
,
fa_version_unsupported_reason
(
fa_version
))
fa_version
,
fa_version_unsupported_reason
(
fa_version
))
assert
is_fa_version_supported
(
fa_version
)
assert
is_fa_version_supported
(
fa_version
)
+
12
return
fa_version
return
fa_version
except
(
ImportError
,
AssertionError
):
except
(
ImportError
,
AssertionError
):
return
None
return
None
def
flash_attn_supports_fp8
()
->
bool
:
def
flash_attn_supports_fp8
()
->
bool
:
if
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
==
"gfx938"
:
return
True
return
get_flash_attn_version
()
==
3
and
\
return
get_flash_attn_version
()
==
3
and
\
current_platform
.
get_device_capability
().
major
==
9
current_platform
.
get_device_capability
().
major
==
9
...
...
vllm/envs.py
View file @
c8bd8db7
...
@@ -210,6 +210,7 @@ if TYPE_CHECKING:
...
@@ -210,6 +210,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT
:
Optional
[
int
]
=
None
VLLM_OPTEST_URLS_PORT
:
Optional
[
int
]
=
None
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_OPTEST_MODELS_PATH
:
str
=
""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_TRITON_PREFIX_FLASH_ATTN
:
bool
=
False
VLLM_USE_QUERY_QUANT
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_FLASH_MLA
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_OPT_OP
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
VLLM_USE_PA_PRINT_PARAM
:
bool
=
False
...
@@ -1534,6 +1535,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1534,6 +1535,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_PREFIX_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_QUERY_QUANT"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, vLLM will use FLASH MLA attention optimizations.
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA"
:
"VLLM_USE_FLASH_MLA"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASH_MLA"
,
"1"
))),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment