Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f3f3c86
Unverified
Commit
0f3f3c86
authored
Jan 06, 2025
by
Roger Wang
Committed by
GitHub
Jan 07, 2025
Browse files
[Bugfix] Update attention interface in `Whisper` (#11784)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
b2785579
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
13 deletions
+11
-13
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+11
-13
No files found.
vllm/model_executor/models/whisper.py
View file @
0f3f3c86
...
@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
...
@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
,
prefix
=
f
"
{
prefix
}
.attn"
,
attn_type
=
self
.
attn_type
,
)
)
def
_init_qkv
(
def
_init_qkv
(
...
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
...
@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
k
,
v
,
kv_cache
,
attn_metadata
,
attn_type
=
self
.
attn_type
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
,
prefix
=
prefix
,
attn_type
=
AttentionType
.
ENCODER_DECODER
,
)
)
def
_init_qkv
(
def
_init_qkv
(
...
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
...
@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
else
:
else
:
k
=
v
=
None
k
=
v
=
None
attn_output
=
self
.
attn
(
q
,
attn_output
=
self
.
attn
(
k
,
q
,
v
,
k
,
kv_cache
,
v
,
attn_metadata
,
kv_cache
,
attn_type
=
AttentionType
.
ENCODER_DECODER
)
attn_metadata
,
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
...
@@ -734,4 +732,4 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -734,4 +732,4 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
loaded_weights
=
[(
name
,
loaded_weight
)
loaded_weights
=
[(
name
,
loaded_weight
)
for
name
,
loaded_weight
in
weights
]
for
name
,
loaded_weight
in
weights
]
mapper
=
WeightsMapper
({
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
})
mapper
=
WeightsMapper
({
".fc1."
:
".mlp.fc1."
,
".fc2."
:
".mlp.fc2."
})
return
loader
.
load_weights
(
loaded_weights
,
mapper
=
mapper
)
return
loader
.
load_weights
(
loaded_weights
,
mapper
=
mapper
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment