"examples/vscode:/vscode.git/clone" did not exist on "2c862dcb82257bf8f34ca2cbc618c3112af540ea"
Unverified commit 9e4df7c4 authored by Joao Gante, committed by GitHub

Generate: replace breaks by a loop condition (#29662)



* replace breaks by a loop condition

* Update src/transformers/generation/utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 28de2f4d
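The diff below (src/transformers/generation/utils.py) deletes the copy-pasted `while True:` loop and its `synced_gpus` early-`break` block from every decoding method, and moves the stop decision into a single helper, `_has_unfinished_sequences`, which serves as the loop condition and performs the all_reduce handshake required under ZeRO stage 3. A minimal, self-contained sketch of that pattern follows; `ToyGenerationLoop` and `generate_steps` are illustrative stand-ins, not library API:

    import torch
    import torch.distributed as dist


    class ToyGenerationLoop:
        def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device) -> bool:
            """Returns True while generation should continue; ZeRO stage 3-friendly."""
            if synced_gpus:
                # Every rank must keep calling `forward` until all ranks finish, so the decision
                # is made collectively: send 0.0 if this rank finished, 1.0 otherwise.
                flag = torch.tensor(0.0 if this_peer_finished else 1.0, device=device)
                dist.all_reduce(flag, op=dist.ReduceOp.SUM)
                return flag.item() != 0.0  # a reduced sum of 0.0 means every peer has finished
            return not this_peer_finished

        def generate_steps(self, max_steps: int, synced_gpus: bool = False, device: str = "cpu") -> int:
            # Before this commit: `while True:` with an in-loop synced_gpus check and scattered breaks.
            # After: the same check is the loop condition itself.
            this_peer_finished = False
            step = 0
            while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=device):
                step += 1  # stand-in for one decoding step
                this_peer_finished = step >= max_steps
            return step

On a single process, `ToyGenerationLoop().generate_steps(3)` runs three steps and exits; with `synced_gpus=True` (inside an initialized `torch.distributed` process group) each rank keeps looping until the reduced flag reports that all ranks are done.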
@@ -1778,6 +1778,24 @@ class GenerationMixin:
         return result
+    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
+        """
+        Returns whether there are still unfinished sequences in the device. The existence of unfinished sequences is
+        fed through `this_peer_finished`. ZeRO stage 3-friendly.
+        """
+        if synced_gpus:
+            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+            # The following logic allows an early break if all peers finished generating their sequence
+            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(device)
+            # send 0.0 if we finished, 1.0 otherwise
+            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+            # did all peers finish? the reduced sum will be 0.0 then
+            if this_peer_finished_flag.item() == 0.0:
+                return False
+        elif this_peer_finished:
+            return False
+        return True
     def contrastive_search(self, *args, **kwargs):
         logger.warning_once(
             "Calling `contrastive_search` directly is deprecated and will be removed in v4.41. Use `generate` or a "
@@ -1939,19 +1957,9 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
             # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
             if model_kwargs.get("past_key_values") is None:
@@ -2187,12 +2195,7 @@ class GenerationMixin:
             # stop when each sentence is finished
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
         if streamer is not None:
             streamer.end()
@@ -2395,6 +2398,7 @@ class GenerationMixin:
         )
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2403,18 +2407,7 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2480,13 +2473,7 @@ class GenerationMixin:
             )
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
         if streamer is not None:
             streamer.end()
@@ -2699,6 +2686,7 @@ class GenerationMixin:
         )
         # keep track of which sequences are already finished
+        this_peer_finished = False
         batch_size, cur_len = (
             model_kwargs["attention_mask"].shape
             if model_kwargs.get("attention_mask", None) is not None
@@ -2707,19 +2695,7 @@ class GenerationMixin:
         unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
         model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device)
-        this_peer_finished = False  # used by synced_gpus only
-        # auto-regressive generation
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # prepare model inputs
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2787,13 +2763,7 @@ class GenerationMixin:
             )
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
         if streamer is not None:
             streamer.end()
@@ -3052,20 +3022,11 @@ class GenerationMixin:
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             # if sequential is True, split the input to batches of batch_size and run sequentially
@@ -3192,9 +3153,6 @@ class GenerationMixin:
             cur_len = cur_len + 1
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
         sequence_outputs = beam_scorer.finalize(
@@ -3441,20 +3399,10 @@ class GenerationMixin:
         beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
         beam_scores = beam_scores.view((batch_size * num_beams,))
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             outputs = self(
@@ -3549,9 +3497,6 @@ class GenerationMixin:
             cur_len = cur_len + 1
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
         sequence_outputs = beam_scorer.finalize(
@@ -3804,20 +3749,10 @@ class GenerationMixin:
         beam_scores[:, ::num_sub_beams] = 0
         beam_scores = beam_scores.view((batch_size * num_beams,))
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             # predicted tokens in cur_len step
             current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)
@@ -3955,9 +3890,6 @@ class GenerationMixin:
             cur_len = cur_len + 1
             if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
         final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
@@ -4213,20 +4145,10 @@ class GenerationMixin:
         beam_scores[:, 1:] = -1e9
         beam_scores = beam_scores.view((batch_size * num_beams,))
-        this_peer_finished = False  # used by synced_gpus only
+        this_peer_finished = False
         decoder_prompt_len = input_ids.shape[-1]  # record the prompt length of decoder
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
             outputs = self(
@@ -4320,9 +4242,6 @@ class GenerationMixin:
             cur_len = cur_len + 1
             if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
-                if not synced_gpus:
-                    break
-                else:
-                    this_peer_finished = True
+                this_peer_finished = True
         sequence_outputs = constrained_beam_scorer.finalize(
@@ -4553,18 +4472,8 @@ class GenerationMixin:
         # other auxiliary variables
         max_len = stopping_criteria[0].max_length
-        this_peer_finished = False  # used by synced_gpus only
-        while True:
-            if synced_gpus:
-                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
-                # The following logic allows an early break if all peers finished generating their sequence
-                this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
-                # send 0.0 if we finished, 1.0 otherwise
-                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
-                # did all peers finish? the reduced sum will be 0.0 then
-                if this_peer_finished_flag.item() == 0.0:
-                    break
+        this_peer_finished = False
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
             cur_len = input_ids.shape[-1]
             # 1. Fetch candidate sequences from a `CandidateGenerator`
@@ -4733,13 +4642,7 @@ class GenerationMixin:
             )
             unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
-            # stop when each sentence is finished
-            if unfinished_sequences.max() == 0:
-                this_peer_finished = True
-            if this_peer_finished and not synced_gpus:
-                break
+            this_peer_finished = unfinished_sequences.max() == 0
         if streamer is not None:
             streamer.end()
...