Commit 43546076 authored by zhuwenwen
Browse files

Add VLLM_USE_FLASH_ATTN_FP8 to support FlashAttention FP8

parent 1663f34c
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt2.' + sha[:7] version = 'das.opt3.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt2' version = 'das.opt3'
# dtk version # dtk version
......
...@@ -145,6 +145,7 @@ if TYPE_CHECKING: ...@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_MODELS_PATH: str = "" VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_OPT_MLA": "VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))), lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
# If set, vLLM will use FLASH MLA attention optimizations. # If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
...@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q, q,
k, k,
v, v,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=False, return_softmax_lse=False,
softmax_scale=None, softmax_scale=None,
**kwargs): **kwargs):
...@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=maybe_padded_v, v=maybe_padded_v,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
**kwargs, **kwargs,
) )
...@@ -978,6 +984,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -978,6 +984,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse = \ attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims( self._flash_attn_varlen_diff_headdims(
q=q, q=q,
...@@ -989,6 +1024,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -989,6 +1024,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale, softmax_scale=self.scale,
causal=False, # Context is unmasked causal=False, # Context is unmasked
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=True, return_softmax_lse=True,
) )
...@@ -1043,6 +1081,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1043,6 +1081,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=has_context,
)
else:
output = self._flash_attn_varlen_diff_headdims( output = self._flash_attn_varlen_diff_headdims(
q=q, q=q,
k=k, k=k,
...@@ -1053,6 +1119,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1053,6 +1119,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=attn_metadata.prefill.max_query_len, max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=has_context, return_softmax_lse=has_context,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment