Commit 2e1f5a46 authored by 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

parents 8ba8a855 1e622f10
...@@ -266,7 +266,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -266,7 +266,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn), q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
...@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -288,6 +288,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=decode_meta.decode_num_splits, num_splits=decode_meta.decode_num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
......
...@@ -101,12 +101,13 @@ class Attention(nn.Module): ...@@ -101,12 +101,13 @@ class Attention(nn.Module):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
self._k_scale = torch.ones((1), dtype=torch.float32) self._k_scale = torch.ones((1), dtype=torch.float32)
self._v_scale = torch.ones((1), dtype=torch.float32) self._v_scale = torch.ones((1), dtype=torch.float32)
self._q_scale = torch.ones((1), dtype=torch.float32)
else: else:
self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32)
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only # FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well. # but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32) self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep the float32 versions of k/v_scale for attention # We also keep the float32 versions of k/v_scale for attention
......
...@@ -101,6 +101,8 @@ def flash_mla_with_kvcache( ...@@ -101,6 +101,8 @@ def flash_mla_with_kvcache(
num_splits: torch.Tensor, num_splits: torch.Tensor,
softmax_scale: Optional[float] = None, softmax_scale: Optional[float] = None,
causal: bool = False, causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
k_scale = None, k_scale = None,
kv_cache_dtype = "auto", kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
...@@ -145,7 +147,6 @@ def flash_mla_with_kvcache( ...@@ -145,7 +147,6 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
None,
head_dim_v, head_dim_v,
cache_seqlens, cache_seqlens,
block_table, block_table,
...@@ -153,6 +154,8 @@ def flash_mla_with_kvcache( ...@@ -153,6 +154,8 @@ def flash_mla_with_kvcache(
causal, causal,
tile_scheduler_metadata, tile_scheduler_metadata,
num_splits, num_splits,
is_fp8_kvcache,
indices,
) )
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
......
...@@ -194,7 +194,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -194,7 +194,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn), q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).view(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
...@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -232,6 +232,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits=attn_metadata.decode.num_splits, num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
is_fp8_kvcache=False,
indices= None,
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment