"vscode:/vscode.git/clone" did not exist on "1367142afd363e2799e3299b9bbf14fcb5e848c0"
Unverified Commit 5ec368d7 authored by CHI LIU's avatar CHI LIU Committed by GitHub
Browse files

Correct eos_token_id settings in generate (#15403)

* Correct eos_token_id set in generate

* Set eos_token_id in test

* Correct eos_token_id set in generate

* Set eos_token_id in test
parent 39b5d1a6
......@@ -1049,6 +1049,8 @@ class GenerationMixin:
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is None and hasattr(self.config, "decoder"):
eos_token_id = self.config.decoder.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
......
......@@ -395,6 +395,9 @@ class EncoderDecoderMixin:
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
enc_dec_model.config.decoder.eos_token_id = None
enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment