Unverified Commit 857fd8ea authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: missing generation config eos token setting in encoder-decoder tests (#29146)

parent 1c81132e
...@@ -473,6 +473,8 @@ class EncoderDecoderMixin: ...@@ -473,6 +473,8 @@ class EncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
enc_dec_model.to(torch_device) enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead # Bert does not have a bos token id, so use pad_token_id instead
......
...@@ -377,6 +377,8 @@ class TFEncoderDecoderMixin: ...@@ -377,6 +377,8 @@ class TFEncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
# Bert does not have a bos token id, so use pad_token_id instead # Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate( generated_output = enc_dec_model.generate(
......
...@@ -351,6 +351,8 @@ class EncoderDecoderMixin: ...@@ -351,6 +351,8 @@ class EncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
inputs = input_values if input_features is None else input_features inputs = input_values if input_features is None else input_features
......
...@@ -308,6 +308,8 @@ class TFVisionEncoderDecoderMixin: ...@@ -308,6 +308,8 @@ class TFVisionEncoderDecoderMixin:
enc_dec_model.config.eos_token_id = None enc_dec_model.config.eos_token_id = None
if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"):
enc_dec_model.config.decoder.eos_token_id = None enc_dec_model.config.decoder.eos_token_id = None
if hasattr(enc_dec_model.generation_config, "eos_token_id"):
enc_dec_model.generation_config.eos_token_id = None
# Bert does not have a bos token id, so use pad_token_id instead # Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate( generated_output = enc_dec_model.generate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment