Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f3f3c86
Unverified
Commit
0f3f3c86
authored
Jan 06, 2025
by
Roger Wang
Committed by
GitHub
Jan 07, 2025
Browse files
[Bugfix] Update attention interface in `Whisper` (#11784)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
b2785579
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
13 deletions
+11
-13
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+11
-13
No files found.
vllm/model_executor/models/whisper.py
View file @
0f3f3c86
...
@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
...
@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
self
.
attn_type
,
)
)
def
_init_qkv
(
def
_init_qkv
(
...
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
...
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
self
.
attn_type
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
attn_type
=
AttentionType
.
ENCODER_DECODER
,
)
)
def
_init_qkv
(
def
_init_qkv
(
...
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
else
:
else
:
k
=
v
=
None
k
=
v
=
None
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
k
,
v
,
v
,
kv_cache
,
kv_cache
,
attn_metadata
,
attn_metadata
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment