Unverified Commit 147e8ce4 authored by Pietro Lesci, committed by GitHub

Remove redundant code from T5 encoder mask creation (#27216)

* remove redundant code

* update

* add typecasting

* make `attention_mask` float again
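Net effect of the change, as read from the diff below (a summary, not an authoritative description): the early, decoder-only creation of a default `encoder_attention_mask` is removed because the mask is rebuilt later from `encoder_hidden_states` anyway; the default `attention_mask` creation moves below the `past_key_values` initialization and keeps `torch.ones`'s default float dtype, while the later `encoder_attention_mask` gains an explicit `dtype=torch.long` so it matches the integer dtype of the removed code. A minimal standalone sketch of the dtype behavior (shapes and variable values here are illustrative, not taken from the model code):

import torch

batch_size, mask_seq_length, encoder_seq_length = 2, 5, 7

# Self-attention mask: created with torch.ones and no explicit dtype, so it is
# float32 by default ("make `attention_mask` float again"), assuming the global
# default dtype has not been changed.
attention_mask = torch.ones(batch_size, mask_seq_length)
assert attention_mask.dtype == torch.float32

# Cross-attention mask: now built only from the encoder hidden shape, with an
# explicit long dtype matching the removed early-initialization path.
encoder_hidden_shape = (batch_size, encoder_seq_length)
encoder_attention_mask = torch.ones(encoder_hidden_shape, dtype=torch.long)
assert encoder_attention_mask.dtype == torch.int64  # torch.long is int64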
parent a6c82d45
@@ -997,18 +997,13 @@ class MT5Stack(MT5PreTrainedModel):
             if not self.is_decoder:
                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
-            )
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
 
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -1019,7 +1014,9 @@ class MT5Stack(MT5PreTrainedModel):
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
@@ -1024,18 +1024,13 @@ class T5Stack(T5PreTrainedModel):
             if not self.is_decoder:
                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
-            )
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
 
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -1046,7 +1041,9 @@ class T5Stack(T5PreTrainedModel):
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
-                )
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
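For context on why the integer vs. float dtype of these masks is largely immaterial downstream: both `get_extended_attention_mask` and `invert_attention_mask` (helpers inherited from the library's model base class) turn the 0/1 mask into an additive float mask before it reaches the attention scores. A simplified sketch of that conversion, under stated assumptions (the helper name `to_additive_mask` is made up here, and the real helpers also handle causal masking for decoders):

import torch

def to_additive_mask(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # [batch, seq_len] 0/1 mask -> [batch, 1, 1, seq_len] additive mask:
    # attended positions become 0.0, masked positions a large negative value
    # that drives their softmax weight to ~0.
    extended = mask[:, None, None, :].to(dtype)
    return (1.0 - extended) * torch.finfo(dtype).min

# The result is the same whether the incoming mask is long or float:
mask_long = torch.tensor([[1, 1, 0]], dtype=torch.long)
assert torch.equal(to_additive_mask(mask_long), to_additive_mask(mask_long.float()))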