Commit 8b7daa0d authored by yangql
Browse files

修复kvcachefp8与cp/pc的冲突

parent 5c288d91
...@@ -2162,9 +2162,19 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2162,9 +2162,19 @@ def gather_cache(src_cache: torch.Tensor,
block_table: torch.Tensor, block_table: torch.Tensor,
cu_seq_lens: torch.Tensor, cu_seq_lens: torch.Tensor,
batch_size: int, batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None: seq_starts: Optional[torch.Tensor] = None,
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, kv_dtype = "auto",
cu_seq_lens, batch_size, seq_starts) scale: float = 1.0,
) -> None:
#支持"kv cache fp8"
if kv_dtype == "fp8":
dst_fp8 = torch.zeros(dst.shape, dtype=torch.uint8, device=dst.device)
convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
def get_device_attribute(attribute: int, device: int) -> int: def get_device_attribute(attribute: int, device: int) -> int:
......
...@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
): ):
prefill_metadata = attn_metadata.prefill_metadata prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None assert prefill_metadata is not None
...@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i], cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills, batch_size=prefill_metadata.num_prefills,
seq_starts=prefill_metadata.context_chunk_starts[i], seq_starts=prefill_metadata.context_chunk_starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
) )
kv_c_normed = workspace[:toks]\ kv_c_normed = workspace[:toks]\
...@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor: ) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata prefill_metadata = attn_metadata.prefill_metadata
...@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2 # ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output) output = torch.empty_like(suffix_output)
merge_attn_states( merge_attn_states(
...@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill: if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill( output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata) attn_metadata, kv_scale=layer._k_scale)
if has_decode: if has_decode:
decode_q_nope, decode_q_pe = decode_q.split( decode_q_nope, decode_q_pe = decode_q.split(
......
...@@ -204,7 +204,8 @@ class Attention(nn.Module): ...@@ -204,7 +204,8 @@ class Attention(nn.Module):
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation: if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
self.calc_kv_scales(query, key, value) self.calc_kv_scales(query, key, value)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
......
...@@ -894,6 +894,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -894,6 +894,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q: torch.Tensor, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
): ):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
...@@ -913,6 +914,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -913,6 +914,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills, batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i], seq_starts=prefill_metadata.chunked_context.starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
) )
kv_c_normed = workspace[:toks]\ kv_c_normed = workspace[:toks]\
...@@ -972,6 +975,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -972,6 +975,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor: ) -> torch.Tensor:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -1006,7 +1010,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1006,7 +1010,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context: if has_context:
suffix_output, suffix_lse = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output) output = torch.empty_like(suffix_output)
merge_attn_states( merge_attn_states(
...@@ -1095,7 +1099,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1095,7 +1099,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill: if has_prefill:
output[num_decode_tokens:] = self._forward_prefill( output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata) attn_metadata, kv_scale=layer._k_scale)
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment