Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
00298e09
"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "21eb2c3372fb6447ef36bee44ff7af79a330ffec"
Unverified
Commit
00298e09
authored
Oct 12, 2024
by
Xiang Xu
Committed by
GitHub
Oct 12, 2024
Browse files
[Bugfix] Fix bug of xformer prefill for encoder-decoder (#9026)
parent
89feb4c8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
11 deletions
+18
-11
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+18
-11
No files found.
vllm/attention/backends/xformers.py
View file @
00298e09
...
@@ -559,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -559,25 +559,32 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
k_scale
,
v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
if
attn_type
==
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
else
:
# Encoder attention - chunked prefill is not applicable;
# Encoder attention - chunked prefill is not applicable;
# derive token-count from query shape & and treat them
# derive token-count from query shape & and treat them
# as 100% prefill tokens
# as 100% prefill tokens
assert
attn_metadata
.
num_encoder_tokens
is
not
None
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
num_decode_tokens
=
0
num_decode_tokens
=
0
elif
attn_type
==
AttentionType
.
DECODER
:
if
attn_type
==
AttentionType
.
DECODER
:
# Decoder self-attention supports chunked prefill.
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
# Only enforce this shape-constraint for decoder
# Only enforce this shape-constraint for decoder
# self-attention
# self-attention
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
key
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
assert
value
.
shape
[
0
]
==
num_prefill_tokens
+
num_decode_tokens
else
:
# attn_type == AttentionType.ENCODER_DECODER
# Encoder/decoder cross-attention requires no chunked
# prefill (100% prefill or 100% decode tokens, no mix)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
if
attn_metadata
.
num_encoder_tokens
is
not
None
:
num_encoder_tokens
=
attn_metadata
.
num_encoder_tokens
else
:
num_encoder_tokens
=
attn_metadata
.
num_prefill_tokens
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
# Query for decode. KV is not needed because it is already cached.
...
@@ -585,8 +592,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -585,8 +592,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# QKV for prefill.
# QKV for prefill.
query
=
query
[:
num_prefill_tokens
]
query
=
query
[:
num_prefill_tokens
]
if
key
is
not
None
and
value
is
not
None
:
if
key
is
not
None
and
value
is
not
None
:
key
=
key
[:
num_
prefill
_tokens
]
key
=
key
[:
num_
encoder
_tokens
]
value
=
value
[:
num_
prefill
_tokens
]
value
=
value
[:
num_
encoder
_tokens
]
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment