Unverified Commit f464f10a authored by Patrick von Platen, committed by GitHub

[Generate] Remove outdated code (#11331)

* remove update function

* update

* refactor more

* refactor
parent bfd83c17
@@ -483,31 +483,6 @@ class GenerationMixin:
         model_kwargs["encoder_outputs"] = encoder_outputs
         return input_ids, model_kwargs

-    @staticmethod
-    def _init_sequence_length_for_generation(
-        input_ids: torch.LongTensor, max_length: int
-    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
-        sequence_lengths = input_ids.new(input_ids.shape[0]).fill_(max_length)
-        cur_len = input_ids.shape[-1]
-        return sequence_lengths, unfinished_sequences, cur_len
-
-    @staticmethod
-    def _update_seq_length_for_generation(
-        sequence_lengths: torch.LongTensor,
-        unfinished_sequences: torch.LongTensor,
-        cur_len: int,
-        is_eos_in_next_token: torch.BoolTensor,
-    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
-        # flag sequences that were still unfinished and emit an EOS token at this step
-        is_sent_unfinished = unfinished_sequences.mul(is_eos_in_next_token.long()).bool()
-        # record the final length of the sequences finishing now
-        sequence_lengths = sequence_lengths.masked_fill(is_sent_unfinished, cur_len)
-        unfinished_sequences = unfinished_sequences.mul((~is_eos_in_next_token).long())
-        return sequence_lengths, unfinished_sequences
-
     @staticmethod
     def _update_model_kwargs_for_generation(
         outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
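The removed helpers tracked both `sequence_lengths` and `unfinished_sequences`, but the refactored loops below read only the latter, so the whole update collapses to one `mul`. A minimal standalone sketch of that surviving update, with a toy batch and an assumed `eos_token_id = 2`:

```python
import torch

eos_token_id = 2
# 1 = still generating, one entry per sequence in the batch
unfinished_sequences = torch.ones(4, dtype=torch.long)
next_tokens = torch.tensor([5, 2, 7, 2])  # sequences 1 and 3 emit EOS here

# equivalent to the removed helper's `mul((~is_eos_in_next_token).long())`
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

print(unfinished_sequences)             # tensor([1, 0, 1, 0])
print(unfinished_sequences.max() == 0)  # tensor(False) -> keep generating
```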
@@ -1271,10 +1246,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )

-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]

         this_peer_finished = False  # used by synced_gpus only
         while cur_len < max_length:
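The replacement init uses `input_ids.new(...)`, which allocates the flag tensor with the dtype and device of `input_ids`, so the bookkeeping follows the model onto the GPU. A toy illustration, assuming a batch of three two-token prompts:

```python
import torch

input_ids = torch.tensor([[101, 42], [101, 7], [101, 9]])  # (batch=3, seq_len=2)

# same two lines as the diff: one flag per sequence, initialized to "unfinished"
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
cur_len = input_ids.shape[-1]

print(unfinished_sequences)  # tensor([1, 1, 1])
print(cur_len)               # 2
```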
@@ -1330,29 +1304,23 @@ class GenerationMixin:
             # argmax
             next_tokens = torch.argmax(next_tokens_scores, dim=-1)

-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1

-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
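Run in isolation with hypothetical toy scores (and assumed `pad_token_id = 0`, `eos_token_id = 2`), one greedy step shows how an already-finished row keeps emitting padding while the flags gate the stop condition:

```python
import torch

pad_token_id, eos_token_id = 0, 2
input_ids = torch.tensor([[5, 6], [7, 8]])
unfinished_sequences = torch.tensor([1, 0])  # second sequence finished earlier
next_tokens_scores = torch.tensor([[0.1, 0.2, 0.3, 2.0],
                                   [1.5, 0.1, 0.2, 0.3]])

next_tokens = torch.argmax(next_tokens_scores, dim=-1)  # tensor([3, 0])
# finished sentences get pad_token_id instead of their argmax token
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

print(input_ids)             # tensor([[5, 6, 3], [7, 8, 0]])
print(unfinished_sequences)  # tensor([1, 0])
```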
@@ -1511,10 +1479,9 @@ class GenerationMixin:
             model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
         )

-        # init sequence length tensors
-        sequence_lengths, unfinished_sequences, cur_len = self._init_sequence_length_for_generation(
-            input_ids, max_length
-        )
+        # keep track of which sequences are already finished
+        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+        cur_len = input_ids.shape[-1]

         this_peer_finished = False  # used by synced_gpus only

         # auto-regressive generation
@@ -1571,32 +1538,25 @@ class GenerationMixin:
             # sample
             probs = F.softmax(next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

-            # add code that transforms next_tokens to tokens_to_add
+            # finished sentences should have their next token be a padding token
             if eos_token_id is not None:
                 assert pad_token_id is not None, "If eos_token_id is defined, make sure that pad_token_id is defined."
                 next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

-            # add token and increase length by one
+            # update generated ids, model inputs, and length for next step
             input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-
-            # update sequence length
-            if eos_token_id is not None:
-                sequence_lengths, unfinished_sequences = self._update_seq_length_for_generation(
-                    sequence_lengths, unfinished_sequences, cur_len, next_tokens == eos_token_id
-                )
-
-            # update model kwargs
             model_kwargs = self._update_model_kwargs_for_generation(
                 outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
             )
-
-            # increase cur_len
             cur_len = cur_len + 1

-            # stop when there is a </s> in each sentence, or if we exceed the maximum length
+            # if eos_token was found in one sentence, set sentence to finished
+            if eos_token_id is not None:
+                unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
+
+            # stop when each sentence is finished, or if we exceed the maximum length
             if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                 if not synced_gpus:
                     break
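The sampled variant differs from greedy search only in how `next_tokens` is drawn; a sketch with toy scores, leaving out the logits warpers (temperature, top-k, top-p) that normally run first:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # make the draw reproducible
next_token_scores = torch.tensor([[0.1, 0.2, 3.0, 0.4],
                                  [2.5, 0.1, 0.1, 0.1]])

# same two lines as the diff: sample one token id per sequence
probs = F.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

print(next_tokens.shape)  # torch.Size([2]) -> one token per batch row
```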