Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
efae6645
Unverified
Commit
efae6645
authored
Jun 27, 2020
by
Sylvain Gugger
Committed by
GitHub
Jun 27, 2020
Browse files
Fix `xxx_length` behavior when using XLNet in pipeline (#5319)
parent
393b8dc0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
src/transformers/pipelines.py
src/transformers/pipelines.py
+12
-2
No files found.
src/transformers/pipelines.py
View file @
efae6645
...
...
@@ -586,7 +586,7 @@ class TextGenerationPipeline(Pipeline):
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing.
<eod> </s> <eos>
"""
with people, even a bishop, begging for his blessing. """
ALLOWED_MODELS
=
[
"XLNetLMHeadModel"
,
...
...
@@ -619,8 +619,18 @@ class TextGenerationPipeline(Pipeline):
# Manage correct placement of the tensors
with
self
.
device_placement
():
if
self
.
model
.
__class__
.
__name__
in
[
"XLNetLMHeadModel"
,
"TransfoXLLMHeadModel"
]:
# For XLNet and TransformerXL we had an article to the prompt to give more state to the model.
padding_text
=
self
.
PADDING_TEXT
+
self
.
tokenizer
.
eos_token
padding
=
self
.
_parse_and_tokenize
(
padding_text
,
padding
=
False
,
add_special_tokens
=
False
)
# This impacts max_length and min_length argument that need adjusting.
padding_length
=
padding
[
"input_ids"
].
shape
[
-
1
]
if
"max_length"
in
generate_kwargs
and
generate_kwargs
[
"max_length"
]
is
not
None
:
generate_kwargs
[
"max_length"
]
+=
padding_length
if
"min_length"
in
generate_kwargs
and
generate_kwargs
[
"min_length"
]
is
not
None
:
generate_kwargs
[
"min_length"
]
+=
padding_length
inputs
=
self
.
_parse_and_tokenize
(
self
.
PADDING_TEXT
+
prompt_text
,
padding
=
False
,
add_special_tokens
=
False
padding_text
+
prompt_text
,
padding
=
False
,
add_special_tokens
=
False
)
else
:
inputs
=
self
.
_parse_and_tokenize
(
prompt_text
,
padding
=
False
,
add_special_tokens
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment