Unverified Commit f5a49bfa authored by David del Río Medina's avatar David del Río Medina Committed by GitHub
Browse files

Replace assert statements with exceptions (#13871) (#13901)

* Replace assert statements with exceptions (#13871)

* Change f-strings when not needed (flake8)

* Replace assert statements with exceptions (#13871)

* Change f-strings when not needed (flake8)

* Improve error message as suggested by reviewer

* Fix identation bug

* Fix style errors
parent 70f186f6
...@@ -214,7 +214,17 @@ class BeamSearchScorer(BeamScorer): ...@@ -214,7 +214,17 @@ class BeamSearchScorer(BeamScorer):
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps) batch_size = len(self._beam_hyps)
assert batch_size == (input_ids.shape[0] // self.group_size) if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
f"size of {self.group_size} is expected by the beam scorer."
)
else:
raise ValueError(
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
f"{self.group_size} is expected by the beam scorer."
)
device = input_ids.device device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device) next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
...@@ -223,12 +233,10 @@ class BeamSearchScorer(BeamScorer): ...@@ -223,12 +233,10 @@ class BeamSearchScorer(BeamScorer):
for batch_idx, beam_hyp in enumerate(self._beam_hyps): for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]: if self._done[batch_idx]:
assert ( if self.num_beams < len(beam_hyp):
len(beam_hyp) >= self.num_beams raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
), f"Batch can only be done if at least {self.num_beams} beams have been generated" if eos_token_id is None or pad_token_id is None:
assert ( raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
eos_token_id is not None and pad_token_id is not None
), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
# pad the batch # pad the batch
next_beam_scores[batch_idx, :] = 0 next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id next_beam_tokens[batch_idx, :] = pad_token_id
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment