Unverified Commit ebd52589 authored by raghavanone, committed by GitHub

Change the way tensor is reshaped in BartAttention (from .view to .reshape) (#21860)

* Change the .view call to .reshape

* Change the .view call to .reshape in all the copies from bart attention

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Revert unnecessary changes

* Revert unnecessary changes

* Revert unnecessary changes

* Revert unnecessary changes
parent f71873c5
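For context (this note is not part of the commit): Tensor.view only succeeds when the requested shape is compatible with the tensor's existing strides, so it can raise a RuntimeError on non-contiguous inputs, whereas Tensor.reshape returns a view when possible and otherwise copies the data. A minimal, hypothetical sketch of the difference, unrelated to the model code below:

import torch

x = torch.randn(2, 4, 8)
y = x.transpose(1, 2)  # same storage as x, but non-contiguous strides

print(y.is_contiguous())  # False

try:
    y.view(2, -1)  # .view cannot flatten across the transposed dimensions
except RuntimeError as err:
    print("view failed:", err)

z = y.reshape(2, -1)  # .reshape falls back to copying the data
print(z.shape)  # torch.Size([2, 32])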
@@ -229,8 +229,8 @@ class BartAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
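For reference, a small sketch of the flattening this code performs before torch.bmm; the sizes here (bsz, num_heads, tgt_len, src_len, head_dim) are made up for illustration and not taken from the diff:

import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 8
proj_shape = (bsz * num_heads, -1, head_dim)

# Heads are folded into the batch dimension because torch.bmm only accepts 3-D inputs.
query_states = torch.randn(bsz, num_heads, tgt_len, head_dim).reshape(*proj_shape)
key_states = torch.randn(bsz, num_heads, src_len, head_dim).reshape(*proj_shape)

attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
print(attn_weights.shape)  # torch.Size([8, 5, 7]) == (bsz * num_heads, tgt_len, src_len)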
@@ -1288,8 +1288,8 @@ class BigBirdPegasusDecoderAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -190,8 +190,8 @@ class BioGptAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -216,8 +216,8 @@ class BlenderbotAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -213,8 +213,8 @@ class BlenderbotSmallAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -419,8 +419,8 @@ class Data2VecAudioAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -451,8 +451,8 @@ class GPTSanJapaneseAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -482,8 +482,8 @@ class HubertAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -284,8 +284,8 @@ class M2M100Attention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -231,8 +231,8 @@ class MarianAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -225,8 +225,8 @@ class MBartAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -231,8 +231,8 @@ class PegasusAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -240,8 +240,8 @@ class PegasusXAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -224,8 +224,8 @@ class PLBartAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -482,8 +482,8 @@ class SEWAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -291,8 +291,8 @@ class Speech2TextAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -237,8 +237,8 @@ class Speech2Text2Attention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -717,8 +717,8 @@ class TimeSeriesTransformerAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -517,8 +517,8 @@ class UniSpeechAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...
@@ -531,8 +531,8 @@ class UniSpeechSatAttention(nn.Module):
         proj_shape = (bsz * self.num_heads, -1, self.head_dim)
         query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
-        key_states = key_states.view(*proj_shape)
-        value_states = value_states.view(*proj_shape)
+        key_states = key_states.reshape(*proj_shape)
+        value_states = value_states.reshape(*proj_shape)

         src_len = key_states.size(1)
         attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
...