Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3f9cb335
Unverified
Commit
3f9cb335
authored
Aug 16, 2023
by
Joao Gante
Committed by
GitHub
Aug 16, 2023
Browse files
Generate: fix default max length warning (#25539)
parent
e13d5b60
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
30 additions
and
4 deletions
+30
-4
src/transformers/generation/flax_utils.py
src/transformers/generation/flax_utils.py
+1
-1
src/transformers/generation/tf_utils.py
src/transformers/generation/tf_utils.py
+1
-1
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+1
-1
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+1
-1
tests/generation/test_utils.py
tests/generation/test_utils.py
+26
-0
No files found.
src/transformers/generation/flax_utils.py
View file @
3f9cb335
...
...
@@ -377,7 +377,7 @@ class FlaxGenerationMixin:
# Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length
=
input_ids
.
shape
[
-
1
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
!
=
20
:
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
=
=
20
:
# 20 is the default max_length of the generation config
warnings
.
warn
(
f
"Using the model-agnostic default `max_length` (=
{
generation_config
.
max_length
}
) "
...
...
src/transformers/generation/tf_utils.py
View file @
3f9cb335
...
...
@@ -829,7 +829,7 @@ class TFGenerationMixin:
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length
=
shape_list
(
input_ids
)[
-
1
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
!
=
20
:
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
=
=
20
:
# 20 is the default max_length of the generation config
warnings
.
warn
(
f
"Using the model-agnostic default `max_length` (=
{
generation_config
.
max_length
}
) "
...
...
src/transformers/generation/utils.py
View file @
3f9cb335
...
...
@@ -1249,7 +1249,7 @@ class GenerationMixin:
"""Performs validation related to the resulting generated length"""
# 1. Max length warnings related to poor parameterization
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
!
=
20
:
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
=
=
20
:
# 20 is the default max_length of the generation config
warnings
.
warn
(
f
"Using the model-agnostic default `max_length` (=
{
generation_config
.
max_length
}
) to control the"
...
...
src/transformers/models/musicgen/modeling_musicgen.py
View file @
3f9cb335
...
...
@@ -1300,7 +1300,7 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length
=
input_ids
.
shape
[
-
1
]
has_default_max_length
=
kwargs
.
get
(
"max_length"
)
is
None
and
generation_config
.
max_length
is
not
None
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
!
=
20
:
if
has_default_max_length
and
generation_config
.
max_new_tokens
is
None
and
generation_config
.
max_length
=
=
20
:
logger
.
warning
(
f
"Using the model-agnostic default `max_length` (=
{
generation_config
.
max_length
}
) "
"to control the generation length. recommend setting `max_new_tokens` to control the maximum length of the generation."
,
...
...
tests/generation/test_utils.py
View file @
3f9cb335
...
...
@@ -16,6 +16,7 @@
import
inspect
import
unittest
import
warnings
import
numpy
as
np
...
...
@@ -2844,3 +2845,28 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
with
self
.
assertRaises
(
TypeError
):
# FakeEncoder.forward() accepts **kwargs -> no filtering -> type error due to unexpected input "foo"
bart_model
.
generate
(
input_ids
,
foo
=
"bar"
)
def test_default_max_length_warning(self):
    """Regression test for #25539: the 'model-agnostic default max_length' warning
    must fire only when generation falls back to the library default (20), not
    when the user (or a custom generation config) chose the length."""
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    # Silence the unrelated "pad_token_id not set" warning so assertions below
    # only see the max_length warning (or its absence).
    model.config.pad_token_id = tokenizer.eos_token_id
    text = "Hello world"
    tokenized_inputs = tokenizer([text], return_tensors="pt")
    input_ids = tokenized_inputs.input_ids.to(torch_device)

    # Default generation config value of 20 -> emits warning
    with self.assertWarns(UserWarning):
        model.generate(input_ids)

    # Explicitly setting max_length to 20 -> no warning
    with warnings.catch_warnings(record=True) as warning_list:
        model.generate(input_ids, max_length=20)
        self.assertEqual(len(warning_list), 0)

    # Generation config max_length != 20 -> no warning
    with warnings.catch_warnings(record=True) as warning_list:
        model.generation_config.max_length = 10
        model.generation_config._from_model_config = False  # otherwise model.config.max_length=20 takes precedence
        model.generate(input_ids)
        self.assertEqual(len(warning_list), 0)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment