Unverified commit 0f3f3c86, authored by Roger Wang, committed by GitHub
Browse files

[Bugfix] Update attention interface in `Whisper` (#11784)


Signed-off-by: Roger Wang <ywang@roblox.com>
parent b2785579
......@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=self.attn_type,
)
def _init_qkv(
......@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=self.attn_type)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.out_proj(attn_output)
......@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
attn_type=AttentionType.ENCODER_DECODER,
)
def _init_qkv(
......@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
else:
k = v = None
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
attn_output = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
output, _ = self.out_proj(attn_output)
......@@ -734,4 +732,4 @@ class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
loaded_weights = [(name, loaded_weight)
for name, loaded_weight in weights]
mapper = WeightsMapper({".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."})
return loader.load_weights(loaded_weights, mapper=mapper)
\ No newline at end of file
return loader.load_weights(loaded_weights, mapper=mapper)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment