Commit ac28ab22 authored by zhuwenwen's avatar zhuwenwen
Browse files

[perf] add VLLM_USE_FLASH_ATTN_FP8 to use fa fp8 attention

parent 5fe03549
...@@ -258,6 +258,7 @@ if TYPE_CHECKING: ...@@ -258,6 +258,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_URLS_PORT: int | None = None VLLM_OPTEST_URLS_PORT: int | None = None
VLLM_OPTEST_MODELS_PATH: str = "" VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_QUERY_QUANT: bool = False VLLM_USE_QUERY_QUANT: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
...@@ -1685,6 +1686,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1685,6 +1686,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
# flag to control if vllm should use q quant # flag to control if vllm should use q quant
"VLLM_USE_QUERY_QUANT": "VLLM_USE_QUERY_QUANT":
lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in lambda: (os.environ.get("VLLM_USE_QUERY_QUANT", "False").lower() in
......
...@@ -1499,6 +1499,34 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1499,6 +1499,34 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
def _run_prefill_new_tokens_fa( def _run_prefill_new_tokens_fa(
self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
): ):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.query_start_loc,
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=return_softmax_lse,
)
else:
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1558,6 +1586,34 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -1558,6 +1586,34 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
): ):
assert prefill.chunked_context is not None assert prefill.chunked_context is not None
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill.query_start_loc,
cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
max_seqlen_q=prefill.max_query_len,
max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
else:
return self._flash_attn_varlen_diff_headdims( return self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment