"examples/vscode:/vscode.git/clone" did not exist on "d1eb88f42def4f7eafcf316feed137912ed522fa"
Unverified Commit 45c7b5b1 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] Small refactor (#15611)

parent c0864d98
......@@ -525,12 +525,6 @@ class GenerationMixin:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * decoder_start_token_id
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int:
if pad_token_id is None and eos_token_id is not None:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
return pad_token_id
def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int:
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
......@@ -1063,9 +1057,15 @@ class GenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -1075,11 +1075,6 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
if pad_token_id is None and eos_token_id is not None:
# special case if pad_token_id is not defined
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
pad_token_id = eos_token_id
# 2. Define model inputs
# inputs_tensor has to be defined
# model_input_name is defined if model-specific keyword input is passed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment