Unverified Commit 648d0deb authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

fix typo in Bart's attention (#21898)

parent c87654dc
...@@ -621,7 +621,7 @@ class Wav2Vec2Attention(nn.Module): ...@@ -621,7 +621,7 @@ class Wav2Vec2Attention(nn.Module):
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}" f" {attn_output.size()}"
) )
...@@ -629,7 +629,7 @@ class Wav2Vec2Attention(nn.Module): ...@@ -629,7 +629,7 @@ class Wav2Vec2Attention(nn.Module):
attn_output = attn_output.transpose(1, 2) attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism. # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
......
...@@ -366,7 +366,7 @@ class WhisperAttention(nn.Module): ...@@ -366,7 +366,7 @@ class WhisperAttention(nn.Module):
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
raise ValueError( raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
f" {attn_output.size()}" f" {attn_output.size()}"
) )
...@@ -374,7 +374,7 @@ class WhisperAttention(nn.Module): ...@@ -374,7 +374,7 @@ class WhisperAttention(nn.Module):
attn_output = attn_output.transpose(1, 2) attn_output = attn_output.transpose(1, 2)
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
# partitioned aross GPUs when using tensor-parallelism. # partitioned across GPUs when using tensor-parallelism.
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_output) attn_output = self.out_proj(attn_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment