Commit 2e1f5a46 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

parents 8ba8a855 1e622f10
......@@ -266,7 +266,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
......@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
......
......@@ -101,12 +101,13 @@ class Attention(nn.Module):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32)
self._q_scale = torch.ones((1), dtype=torch.float32)
else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention
......
......@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
k_scale = None,
kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
......@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
......@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
causal,
tile_scheduler_metadata,
num_splits,
is_fp8_kvcache,
indices,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
......
......@@ -194,7 +194,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
......@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment