Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5af5735f
Unverified
Commit
5af5735f
authored
Apr 28, 2022
by
Yih-Dar
Committed by
GitHub
Apr 28, 2022
Browse files
set eos_token_id to None to generate until max length (#16989)
Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
01562dac
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
24 additions
and
2 deletions
+24
-2
tests/encoder_decoder/test_modeling_encoder_decoder.py
tests/encoder_decoder/test_modeling_encoder_decoder.py
+4
-1
tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
+6
-0
tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
...h_encoder_decoder/test_modeling_speech_encoder_decoder.py
+2
-1
tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
...ncoder_decoder/test_modeling_tf_vision_encoder_decoder.py
+6
-0
tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
...n_encoder_decoder/test_modeling_vision_encoder_decoder.py
+6
-0
No files found.
tests/encoder_decoder/test_modeling_encoder_decoder.py
View file @
5af5735f
...
...
@@ -413,6 +413,9 @@ class EncoderDecoderMixin:
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
# Generate until max length
if
hasattr
(
enc_dec_model
.
config
,
"eos_token_id"
):
enc_dec_model
.
config
.
eos_token_id
=
None
if
hasattr
(
enc_dec_model
.
config
,
"decoder"
)
and
hasattr
(
enc_dec_model
.
config
.
decoder
,
"eos_token_id"
):
enc_dec_model
.
config
.
decoder
.
eos_token_id
=
None
enc_dec_model
.
to
(
torch_device
)
...
...
tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
View file @
5af5735f
...
...
@@ -314,6 +314,12 @@ class TFEncoderDecoderMixin:
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
# Generate until max length
if
hasattr
(
enc_dec_model
.
config
,
"eos_token_id"
):
enc_dec_model
.
config
.
eos_token_id
=
None
if
hasattr
(
enc_dec_model
.
config
,
"decoder"
)
and
hasattr
(
enc_dec_model
.
config
.
decoder
,
"eos_token_id"
):
enc_dec_model
.
config
.
decoder
.
eos_token_id
=
None
# Bert does not have a bos token id, so use pad_token_id instead
generated_output
=
enc_dec_model
.
generate
(
input_ids
,
decoder_start_token_id
=
enc_dec_model
.
config
.
decoder
.
pad_token_id
...
...
tests/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
View file @
5af5735f
...
...
@@ -347,6 +347,7 @@ class EncoderDecoderMixin:
enc_dec_model
.
to
(
torch_device
)
# make sure EOS token is set to None to prevent early stopping of generation
if
hasattr
(
enc_dec_model
.
config
,
"eos_token_id"
):
enc_dec_model
.
config
.
eos_token_id
=
None
if
hasattr
(
enc_dec_model
.
config
,
"decoder"
)
and
hasattr
(
enc_dec_model
.
config
.
decoder
,
"eos_token_id"
):
enc_dec_model
.
config
.
decoder
.
eos_token_id
=
None
...
...
tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py
View file @
5af5735f
...
...
@@ -300,6 +300,12 @@ class TFVisionEncoderDecoderMixin:
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFVisionEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
# Generate until max length
if
hasattr
(
enc_dec_model
.
config
,
"eos_token_id"
):
enc_dec_model
.
config
.
eos_token_id
=
None
if
hasattr
(
enc_dec_model
.
config
,
"decoder"
)
and
hasattr
(
enc_dec_model
.
config
.
decoder
,
"eos_token_id"
):
enc_dec_model
.
config
.
decoder
.
eos_token_id
=
None
# Bert does not have a bos token id, so use pad_token_id instead
generated_output
=
enc_dec_model
.
generate
(
pixel_values
,
decoder_start_token_id
=
enc_dec_model
.
config
.
decoder
.
pad_token_id
...
...
tests/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
View file @
5af5735f
...
...
@@ -269,6 +269,12 @@ class EncoderDecoderMixin:
def
check_encoder_decoder_model_generate
(
self
,
config
,
decoder_config
,
pixel_values
=
None
,
**
kwargs
):
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
VisionEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
# Generate until max length
if
hasattr
(
enc_dec_model
.
config
,
"eos_token_id"
):
enc_dec_model
.
config
.
eos_token_id
=
None
if
hasattr
(
enc_dec_model
.
config
,
"decoder"
)
and
hasattr
(
enc_dec_model
.
config
.
decoder
,
"eos_token_id"
):
enc_dec_model
.
config
.
decoder
.
eos_token_id
=
None
enc_dec_model
.
to
(
torch_device
)
inputs
=
pixel_values
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment