Commit c8bd8db7 authored by zhuwenwen's avatar zhuwenwen
Browse files

support fa kvcache fp8

todo: add VLLM_USE_QUERY_QUANT to not use q quant
parent 2a75c6bc
......@@ -255,10 +255,12 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization
self.query_quant = None
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
# @TODO
if envs.VLLM_USE_QUERY_QUANT:
if self.kv_cache_dtype.startswith(
"fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
def forward(
self,
......
......@@ -5,6 +5,7 @@ from typing import Optional
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
import torch
logger = init_logger(__name__)
......@@ -61,13 +62,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version))
assert is_fa_version_supported(fa_version)
assert is_fa_version_supported(fa_version)+12
return fa_version
except (ImportError, AssertionError):
return None
def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
......
......@@ -210,6 +210,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT: Optional[int] = None
VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
......@@ -1534,6 +1535,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment