Commit e33ed12c authored by Patrick von Platen's avatar Patrick von Platen
Browse files

uncomment expression

parent 4220fd52
...@@ -945,10 +945,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -945,10 +945,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
# scores for each sentence in the beam # scores for each sentence in the beam
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
# Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times # Greedy decoding it is made sure that only tokens of the first beam are considered to avoid sampling the exact same tokens three times
# if do_sample is False: if do_sample is False:
beam_scores[:, 1:] = -1e9 beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,) beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states # cache compute states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment