Unverified Commit d5610b53 authored by Yanming Wang's avatar Yanming Wang Committed by GitHub
Browse files

[XLA] Improve t5 model performance (#18288)

parent e318cda9
@@ -1331,8 +1331,6 @@ class LongT5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
@@ -1414,7 +1412,7 @@ class LongT5Stack(LongT5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
...
@@ -827,8 +827,6 @@ class T5PreTrainedModel(PreTrainedModel):
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

-        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
-
         return shifted_input_ids
@@ -944,7 +942,7 @@ class T5Stack(T5PreTrainedModel):
             assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"

         if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
             encoder_attention_mask = torch.ones(
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment