Unverified commit 4c3aac51, authored by Chen Zhang and committed by GitHub
Browse files

Merging PR #12536

Merged via CLI script
parent bc1bdece
...@@ -156,9 +156,13 @@ class Attention(nn.Module): ...@@ -156,9 +156,13 @@ class Attention(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.calculate_kv_scales and \ # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
attn_metadata.enable_kv_scales_calculation: # directly, use `self.kv_cache` and
self.calc_kv_scales(key, value) # `get_forward_context().attn_metadata` instead.
if self.calculate_kv_scales:
ctx_attn_metadata = get_forward_context().attn_metadata
if ctx_attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output: if self.use_output:
output = torch.empty_like(query) output = torch.empty_like(query)
hidden_size = query.size(-1) hidden_size = query.size(-1)
...@@ -172,15 +176,27 @@ class Attention(nn.Module): ...@@ -172,15 +176,27 @@ class Attention(nn.Module):
if value is not None: if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call: if self.use_direct_call:
unified_attention_with_output(query, key, value, output, forward_context: ForwardContext = get_forward_context()
self.layer_name) ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
ctx_attn_metadata,
output=output)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name) forward_context = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, ctx_attn_metadata)
else: else:
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name) query, key, value, self.layer_name)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment