Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4c3aac51
Unverified
Commit
4c3aac51
authored
Feb 06, 2025
by
Chen Zhang
Committed by
GitHub
Feb 05, 2025
Browse files
Merging PR #12536
Merged via CLI script
parent
bc1bdece
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
6 deletions
+22
-6
vllm/attention/layer.py
vllm/attention/layer.py
+22
-6
No files found.
vllm/attention/layer.py
View file @
4c3aac51
...
@@ -156,8 +156,12 @@ class Attention(nn.Module):
...
@@ -156,8 +156,12 @@ class Attention(nn.Module):
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
calculate_kv_scales
and
\
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
attn_metadata
.
enable_kv_scales_calculation
:
# directly, use `self.kv_cache` and
# `get_forward_context().attn_metadata` instead.
if
self
.
calculate_kv_scales
:
ctx_attn_metadata
=
get_forward_context
().
attn_metadata
if
ctx_attn_metadata
.
enable_kv_scales_calculation
:
self
.
calc_kv_scales
(
key
,
value
)
self
.
calc_kv_scales
(
key
,
value
)
if
self
.
use_output
:
if
self
.
use_output
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
...
@@ -172,15 +176,27 @@ class Attention(nn.Module):
...
@@ -172,15 +176,27 @@ class Attention(nn.Module):
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
unified_attention_with_output
(
query
,
key
,
value
,
output
,
forward_context
:
ForwardContext
=
get_forward_context
()
self
.
layer_name
)
ctx_attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
ctx_attn_metadata
,
output
=
output
)
else
:
else
:
torch
.
ops
.
vllm
.
unified_attention_with_output
(
torch
.
ops
.
vllm
.
unified_attention_with_output
(
query
,
key
,
value
,
output
,
self
.
layer_name
)
query
,
key
,
value
,
output
,
self
.
layer_name
)
return
output
.
view
(
-
1
,
hidden_size
)
return
output
.
view
(
-
1
,
hidden_size
)
else
:
else
:
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
return
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
forward_context
=
get_forward_context
()
ctx_attn_metadata
=
forward_context
.
attn_metadata
self_kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
self_kv_cache
,
ctx_attn_metadata
)
else
:
else
:
return
torch
.
ops
.
vllm
.
unified_attention
(
return
torch
.
ops
.
vllm
.
unified_attention
(
query
,
key
,
value
,
self
.
layer_name
)
query
,
key
,
value
,
self
.
layer_name
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment