Commit 43546076 authored by zhuwenwen
Browse files

Add VLLM_USE_FLASH_ATTN_FP8 to support FlashAttention FP8

parent 1663f34c
......@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt2.' + sha[:7]
version = 'das.opt3.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt2'
version = 'das.opt3'
# dtk version
......
......@@ -145,6 +145,7 @@ if TYPE_CHECKING:
VLLM_OPTEST_MODELS_PATH: str = ""
VLLM_USE_TRITON_PREFIX_FLASH_ATTN: bool = False
VLLM_USE_TRITON_OPT_MLA: bool = False
VLLM_USE_FLASH_ATTN_FP8: bool = False
VLLM_USE_FLASH_MLA: bool = False
VLLM_USE_FLASH_MLA_FP8: bool = False
VLLM_USE_OPT_OP: bool = False
......@@ -1038,6 +1039,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
# If set, vLLM will use FLASH ATTN fp8 attention optimizations.
"VLLM_USE_FLASH_ATTN_FP8":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_ATTN_FP8", "0"))),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "1"))),
......
......@@ -828,6 +828,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q,
k,
v,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=False,
softmax_scale=None,
**kwargs):
......@@ -850,6 +853,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q=q,
k=k,
v=maybe_padded_v,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
softmax_scale=softmax_scale,
**kwargs,
)
......@@ -978,6 +984,35 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
......@@ -989,6 +1024,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=True,
)
......@@ -1043,6 +1081,34 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and envs.VLLM_USE_FLASH_ATTN_FP8:
q_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
k_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
v_descale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
descale_shape = (attn_metadata.prefill.query_start_loc.numel() - 1, q.shape[1])
q_descale = q_descale.expand(descale_shape)
k_descale = k_descale.expand(descale_shape)
v_descale = v_descale.expand(descale_shape)
q = q.to(torch.float8_e4m3fn)
k = k.to(torch.float8_e4m3fn)
v = v.to(torch.float8_e4m3fn)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
max_seqlen_q=attn_metadata.prefill.max_query_len,
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=has_context,
)
else:
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
......@@ -1053,6 +1119,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
max_seqlen_k=attn_metadata.prefill.max_query_len,
softmax_scale=self.scale,
causal=True,
q_descale=None,
k_descale=None,
v_descale=None,
return_softmax_lse=has_context,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment