Commit c1116011 authored by Patrick von Platen

small clean-up

parent 2e81b9d8
@@ -845,7 +845,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                bos_token_id,
+                bos_token_id,  # TODO: wait for results of Bart CNN summarization
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
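
Note: in an encoder-decoder generate() call, the original input_ids are kept as the encoder inputs and the decoder is seeded with a single start token per beam hypothesis, which is what the torch.full call in this hunk builds. A minimal standalone sketch of that seeding step (batch size, beam count and token id are made-up example values, not taken from the commit):

import torch

# assumed example values for illustration only
effective_batch_size = 2
num_beams = 4
bos_token_id = 0

# one start token per (batch item x beam): shape (batch_size * num_beams, 1)
decoder_input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,
    dtype=torch.long,
)
print(decoder_input_ids.shape)  # torch.Size([8, 1])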
@@ -1082,7 +1082,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
             if self.config.is_encoder_decoder and do_sample is False:
-                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                # TODO: maybe give better naming
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
             # set eos token prob to zero if min_length is not reached
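
Note: prepare_scores_for_generation is a model-specific hook that lets an encoder-decoder model adjust the (batch_size * num_beams, vocab_size) log-probabilities from the line above before beam search picks the next tokens. The sketch below shows one plausible behaviour for such a hook, forcing a fixed token at the first and at the last decoder step by masking all alternatives; it is an illustration under those assumptions, not the library's implementation:

import torch

def _force_token_id(scores, token_id):
    # keep only `token_id`: every other continuation gets log-prob -inf
    forced = torch.full_like(scores, float("-inf"))
    forced[:, token_id] = scores[:, token_id]
    return forced

def prepare_scores_for_generation_sketch(scores, cur_len, max_length, bos_token_id=0, eos_token_id=2):
    # hypothetical hook: force BOS on the first decoder step and EOS on the final one
    if cur_len == 1:
        return _force_token_id(scores, bos_token_id)
    if cur_len == max_length - 1:
        return _force_token_id(scores, eos_token_id)
    return scores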
@@ -1276,7 +1276,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
         if self.config.is_encoder_decoder:
-            # do not return first <EOS> token
             return decoded[:, 1:]
         return decoded
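
Note: for encoder-decoder models, position 0 of every decoded sequence is the token that seeded the decoder, so the return value drops that first column; otherwise the sequences are returned unchanged. A toy example of the slicing (the token ids are invented):

import torch

# assumed toy output: two finished hypotheses, first column is the decoder seed token
decoded = torch.tensor([[0, 17, 35, 2],
                        [0, 21, 44, 2]])
print(decoded[:, 1:])
# tensor([[17, 35,  2],
#         [21, 44,  2]])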