Commit c8bd8db7 authored by zhuwenwen's avatar zhuwenwen
Browse files

Support FP8 KV cache in flash attention (FA)

TODO: add a VLLM_USE_QUERY_QUANT flag so that query (Q) quantization can be disabled
parent 2a75c6bc
...@@ -255,10 +255,12 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -255,10 +255,12 @@ class Attention(nn.Module, AttentionLayerBase):
# for attn backends supporting query quantization # for attn backends supporting query quantization
self.query_quant = None self.query_quant = None
if self.kv_cache_dtype.startswith( # @TODO
"fp8") and self.attn_backend.supports_quant_query_input: if envs.VLLM_USE_QUERY_QUANT:
self.query_quant = QuantFP8(static=True, if self.kv_cache_dtype.startswith(
group_shape=GroupShape.PER_TENSOR) "fp8") and self.attn_backend.supports_quant_query_input:
self.query_quant = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
def forward( def forward(
self, self,
......
...@@ -5,6 +5,7 @@ from typing import Optional ...@@ -5,6 +5,7 @@ from typing import Optional
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
import torch
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -61,13 +62,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: ...@@ -61,13 +62,15 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
logger.error("Cannot use FA version %d is not supported due to %s", logger.error("Cannot use FA version %d is not supported due to %s",
fa_version, fa_version_unsupported_reason(fa_version)) fa_version, fa_version_unsupported_reason(fa_version))
assert is_fa_version_supported(fa_version) assert is_fa_version_supported(fa_version)+12
return fa_version return fa_version
except (ImportError, AssertionError): except (ImportError, AssertionError):
return None return None
def flash_attn_supports_fp8() -> bool: def flash_attn_supports_fp8() -> bool:
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
return get_flash_attn_version() == 3 and \ return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9 current_platform.get_device_capability().major == 9
......
...@@ -210,6 +210,7 @@ if TYPE_CHECKING: ...@@ -210,6 +210,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT: Optional[int] = None VLLM_OPTEST_URLS_PORT: Optional[int] = None
VLLM_OPTEST_MODELS_PATH: str = "" VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
...@@ -1534,6 +1535,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1534,6 +1535,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# Flag controlling whether vLLM quantizes the attention query (Q) when the KV cache dtype is FP8
"VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
("true", "1")),
# If set, vLLM will use FLASH MLA attention optimizations. # If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment