Unverified commit 0f3f3c86, authored by Roger Wang, committed by GitHub
Browse files

[Bugfix] Update attention interface in `Whisper` (#11784)


Signed-off-by: Roger Wang <ywang@roblox.com>
parent b2785579
...@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module): ...@@ -106,6 +106,7 @@ class WhisperAttention(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
attn_type=self.attn_type,
) )
def _init_qkv( def _init_qkv(
...@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module): ...@@ -134,12 +135,7 @@ class WhisperAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q, attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
k,
v,
kv_cache,
attn_metadata,
attn_type=self.attn_type)
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
...@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention): ...@@ -164,6 +160,7 @@ class WhisperCrossAttention(WhisperAttention):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix, prefix=prefix,
attn_type=AttentionType.ENCODER_DECODER,
) )
def _init_qkv( def _init_qkv(
...@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention): ...@@ -207,12 +204,13 @@ class WhisperCrossAttention(WhisperAttention):
else: else:
k = v = None k = v = None
attn_output = self.attn(q, attn_output = self.attn(
q,
k, k,
v, v,
kv_cache, kv_cache,
attn_metadata, attn_metadata,
attn_type=AttentionType.ENCODER_DECODER) )
output, _ = self.out_proj(attn_output) output, _ = self.out_proj(attn_output)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment