Commit 43546076 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_FLASH_ATTN_FP8 to support fa fp8

parent 1663f34c
...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt2.' + sha[:7] version = 'das.opt3.' + sha[:7]
else: else:
if (major, minor) >= ('2', '5'): if (major, minor) >= ('2', '5'):
version = 'das.opt2' version = 'das.opt3'
# dtk version # dtk version
......
...@@ -145,6 +145,7 @@ if TYPE_CHECKING: ...@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_MODELS_PATH: str = "" VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_FLASH_MLA: bool = False VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_OPT_MLA": "VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))), lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
# If set, vLLM will use FLASH MLA attention optimizations. # If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA": "VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))), lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
...@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q, q,
k, k,
v, v,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=False, return_softmax_lse=False,
softmax_scale=None, softmax_scale=None,
**kwargs): **kwargs):
...@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q=q, q=q,
k=k, k=k,
v=maybe_padded_v, v=maybe_padded_v,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
**kwargs, **kwargs,
) )
...@@ -978,19 +984,51 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -978,19 +984,51 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1) dim=-1)
attn_output, attn_softmax_lse = \ if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
self._flash_attn_varlen_diff_headdims( q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
q=q, k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k=k, v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v=v, descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
cu_seqlens_q=prefill_metadata.query_start_loc, q_descale = q_descale.expand(descale_shape)
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], k_descale = k_descale.expand(descale_shape)
max_seqlen_q=prefill_metadata.max_query_len, v_descale = v_descale.expand(descale_shape)
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], q = q.to(torch.float8_e4m3fn)
softmax_scale=self.scale, k = k.to(torch.float8_e4m3fn)
causal=False, # Context is unmasked v = v.to(torch.float8_e4m3fn)
return_softmax_lse=True,
) attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=True,
)
if output is None: if output is None:
output = attn_output output = attn_output
...@@ -1043,18 +1081,49 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1043,18 +1081,49 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else: else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims( if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q=q, q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k=k, k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v=v, v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
cu_seqlens_q=attn_metadata.prefill.query_start_loc, descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
cu_seqlens_k=attn_metadata.prefill.query_start_loc, q_descale = q_descale.expand(descale_shape)
max_seqlen_q=attn_metadata.prefill.max_query_len, k_descale = k_descale.expand(descale_shape)
max_seqlen_k=attn_metadata.prefill.max_query_len, v_descale = v_descale.expand(descale_shape)
softmax_scale=self.scale,
causal=True, q = q.to(torch.float8_e4m3fn)
return_softmax_lse=has_context, k = k.to(torch.float8_e4m3fn)
) v = v.to(torch.float8_e4m3fn)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=has_context,
)
else:
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=has_context,
)
if has_context: if has_context:
suffix_output, suffix_lse = output suffix_output, suffix_lse = output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment