Unverified Commit f464f10a authored by Patrick von Platen, committed by GitHub

[Generate] Remove outdated code (#11331)

* remove update function

* update

* refactor more

* refactor
parent bfd83c17
@@ -483,31 +483,6 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _init_sequence_length_for_generation(
-        input_ids: torch.LongTensor, max_length: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
-
-        cur_len = input_ids.shape[-1]
-        return sequence_lengths, unfinished_sequences, cur_len
-
-    @staticmethod
-    def _update_seq_length_for_generation(
-        sequence_lengths: torch.LongTensor,
-        unfinished_sequences: torch.LongTensor,
-        cur_len: int,
-        is_eos_in_next_token: torch.BoolTensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        # check if sentence is not finished yet
-        is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
-
-        # update sentence length
-        sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
-        unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
-        return sequence_lengths, unfinished_sequences
-
     @staticmethod
     def _update_model_kwargs_for_generation(
         outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
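The two helpers removed above are inlined into the decoding loops below as a single `unfinished_sequences` mask. For readers following along, here is a minimal standalone sketch of that simplified bookkeeping, using toy tensors rather than the library code:

```python
import torch

input_ids = torch.tensor([[5, 7], [5, 9]])  # batch of 2 prompts
eos_token_id, pad_token_id = 2, 0

# 1 = still generating, 0 = finished (replaces _init_sequence_length_for_generation)
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
cur_len = input_ids.shape[-1]

next_tokens = torch.tensor([2, 8])  # pretend the model predicted these
# finished rows emit pad_token_id instead of a real token
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
cur_len += 1

# rows that just produced EOS flip to finished (replaces _update_seq_length_for_generation)
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
assert unfinished_sequences.tolist() == [0, 1]  # first row hit EOS (token 2)
```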
@@ -1271,10 +1246,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
         while cur_len < max_length:
@@ -1330,29 +1304,23 @@ class GenerationMixin:
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
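To see the rewritten control flow end to end, here is a toy greedy loop; `next_token_logits` is a hypothetical stand-in for the model forward pass, not a transformers API:

```python
import torch

def next_token_logits(input_ids):
    # hypothetical stand-in for a model: always favor (last_token + 1) % 4
    logits = torch.zeros(input_ids.shape[0], 4)
    logits[torch.arange(input_ids.shape[0]), (input_ids[:, -1] + 1) % 4] = 1.0
    return logits

eos_token_id, pad_token_id, max_length = 3, 0, 8
input_ids = torch.tensor([[1], [2]])
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
cur_len = input_ids.shape[-1]

while cur_len < max_length:
    next_tokens = torch.argmax(next_token_logits(input_ids), dim=-1)
    # finished rows keep emitting pad_token_id
    next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
    input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
    cur_len = cur_len + 1
    # rows that just emitted EOS are marked finished
    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
    if unfinished_sequences.max() == 0:  # every row has finished
        break

print(input_ids)  # tensor([[1, 2, 3], [2, 3, 0]]) -- second row pads out after EOS
```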
@@ -1511,10 +1479,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )
 
-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
 
         # auto-regressive generation
@@ -1571,32 +1538,25 @@ class GenerationMixin:
             # sample
             probs = F.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1
 
-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
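The sample loop keeps the exact same finished-sequence bookkeeping as greedy search; only the token selection differs, drawing from the softmax distribution rather than taking the argmax. A minimal sketch of that step:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # for a reproducible draw
next_token_scores = torch.tensor([[2.0, 0.5, 0.1], [0.1, 3.0, 0.2]])

probs = F.softmax(next_token_scores, dim=-1)                      # per-row distribution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)  # one draw per row
print(next_tokens.shape)  # torch.Size([2]) -- same shape argmax would produce
```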
...