"examples/trials/vscode:/vscode.git/clone" did not exist on "d1b1e7b311380d9ab0454c85c851d8159aa5b67e"
Unverified Commit ebd52589 authored by raghavanone's avatar raghavanone Committed by GitHub
Browse files

Change the way tensor is reshaped in BartAttention (from .view to .reshape) (#21860)

* Change the .view call to .reshape

* Change the .view call to .reshape to all the copies from bart attention

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Fix copies and style

* Revert unneccessary changes

* Revert unneccessary changes

* Revert unneccessary changes

* Revert unneccessary changes
parent f71873c5
......@@ -574,8 +574,8 @@ class Wav2Vec2Attention(nn.Module):
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
......
......@@ -319,8 +319,8 @@ class WhisperAttention(nn.Module):
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
key_states = key_states.reshape(*proj_shape)
value_states = value_states.reshape(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment