Unverified Commit 5af5735f authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

set eos_token_id to None to generate until max length (#16989)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 01562dac
...@@ -413,6 +413,9 @@ class EncoderDecoderMixin:
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
    enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
    enc_dec_model.config.decoder.eos_token_id = None
enc_dec_model.to(torch_device)
......
...@@ -314,6 +314,12 @@ class TFEncoderDecoderMixin:
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
    enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
    enc_dec_model.config.decoder.eos_token_id = None
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
    input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
......
...@@ -347,6 +347,7 @@ class EncoderDecoderMixin:
enc_dec_model.to(torch_device)
# make sure EOS token is set to None to prevent early stopping of generation
if hasattr(enc_dec_model.config, "eos_token_id"):
    enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
    enc_dec_model.config.decoder.eos_token_id = None
......
...@@ -300,6 +300,12 @@ class TFVisionEncoderDecoderMixin:
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = TFVisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
    enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
    enc_dec_model.config.decoder.eos_token_id = None
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
    pixel_values, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
......
...@@ -269,6 +269,12 @@ class EncoderDecoderMixin:
def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
# Generate until max length
if hasattr(enc_dec_model.config, "eos_token_id"):
    enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
    enc_dec_model.config.decoder.eos_token_id = None
enc_dec_model.to(torch_device)
inputs = pixel_values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment