"tests/kernels/attention/test_encoder_decoder_attn.py" did not exist on "e489ad7a210f4234db696d1f2749d5f3662fa65b"
Commit a0895c00 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-yql-kvfp8' into v0.9.2-dev

parents 4020670f 8b7daa0d
......@@ -2162,7 +2162,17 @@ def gather_cache(src_cache: torch.Tensor,
block_table: torch.Tensor,
cu_seq_lens: torch.Tensor,
batch_size: int,
seq_starts: Optional[torch.Tensor] = None) -> None:
seq_starts: Optional[torch.Tensor] = None,
kv_dtype = "auto",
scale: float = 1.0,
) -> None:
#支持"kv cache fp8"
if kv_dtype == "fp8":
dst_fp8 = torch.zeros(dst.shape, dtype=torch.uint8, device=dst.device)
convert_fp8(dst_fp8, dst, scale, kv_dtype)
torch.ops._C_cache_ops.gather_cache(src_cache, dst_fp8, block_table,
cu_seq_lens, batch_size, seq_starts)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -1179,6 +1179,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
):
prefill_metadata = attn_metadata.prefill_metadata
assert prefill_metadata is not None
......@@ -1207,6 +1208,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
batch_size=prefill_metadata.num_prefills,
seq_starts=prefill_metadata.context_chunk_starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
)
kv_c_normed = workspace[:toks]\
......@@ -1262,6 +1265,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor:
prefill_metadata = attn_metadata.prefill_metadata
......@@ -1297,7 +1301,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
......@@ -1387,7 +1391,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, kv_scale=layer._k_scale)
if has_decode:
decode_q_nope, decode_q_pe = decode_q.split(
......
......@@ -205,7 +205,8 @@ class Attention(nn.Module):
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
self.calc_kv_scales(query, key, value)
if self.use_output:
output_shape = (output_shape
......
......@@ -894,6 +894,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
......@@ -913,6 +914,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
)
kv_c_normed = workspace[:toks]\
......@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor:
assert attn_metadata.prefill is not None
......@@ -1015,7 +1019,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
......@@ -1104,7 +1108,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, kv_scale=layer._k_scale)
if has_decode:
assert attn_metadata.decode is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment