Unverified Commit a2ef9c54 authored by Tommy Chiang, committed by GitHub

Use torch.unique_consecutive to check same element (#13637)

We use `torch.unique` here only to check whether all elements have
the same value.
Therefore, we can use `torch.unique_consecutive` instead.

This function eliminates all but the first element from every consecutive
group of equivalent elements.
For example, applying this function to `[1, 2, 2, 1]` yields
`[1, 2, 1]`.

As you can see, this is enough to check whether all elements have
the same value: the result has length 1 exactly when they do.
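
A minimal sketch of that check (the tensor values here are illustrative, not from the patch):

```python
import torch

# Simulated per-example <eos> counts, i.e. eos_mask.sum(1) for a batch
# where every example has the same number of <eos> tokens.
counts = torch.tensor([2, 2, 2, 2])

# A length-1 result means all counts are equal; both calls agree here.
assert len(torch.unique(counts)) == 1
assert len(torch.unique_consecutive(counts)) == 1

# With a mismatch, unique_consecutive keeps non-adjacent duplicates
# (tensor([2, 3, 2])), but the length is still > 1, so the check fires.
assert len(torch.unique_consecutive(torch.tensor([2, 3, 2]))) > 1
```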

Since `torch.unique_consecutive` does less work (it only collapses adjacent duplicates, with no global sort or deduplication), it is much faster.
On my computer, it is 25x faster on GPU and 15x faster on CPU.
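
For reference, a rough way to time the two calls (this sketch uses `torch.utils.benchmark` with a made-up input size; it is not the author's original script, and the 25x/15x figures above are the author's own measurements):

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical input: per-example <eos> counts for a large batch.
x = torch.full((4096,), 2, dtype=torch.long)

for stmt in ("torch.unique(x)", "torch.unique_consecutive(x)"):
    timer = benchmark.Timer(stmt=stmt, globals={"torch": torch, "x": x})
    print(stmt, timer.timeit(1000))
```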
parent 95f888fd
@@ -1457,7 +1457,7 @@ class BartForSequenceClassification(BartPretrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
...
@@ -2668,7 +2668,7 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
...
@@ -2522,7 +2522,7 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
...
@@ -1463,7 +1463,7 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
...
@@ -2972,7 +2972,7 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         eos_mask = input_ids.eq(self.config.eos_token_id)
-        if len(torch.unique(eos_mask.sum(1))) > 1:
+        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
             :, -1, :
...