chenpangpang / transformers · Commits · efae6645

Unverified commit efae6645, authored Jun 27, 2020 by Sylvain Gugger, committed by GitHub Jun 27, 2020.

Fix `xxx_length` behavior when using XLNet in pipeline (#5319)

Parent: 393b8dc0

Showing 1 changed file (src/transformers/pipelines.py) with 12 additions and 2 deletions.
@@ -586,7 +586,7 @@ class TextGenerationPipeline(Pipeline):
     father initially slaps him for making such an accusation, Rasputin watches as the
     man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
     the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
-    with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
+    with people, even a bishop, begging for his blessing. """

     ALLOWED_MODELS = [
         "XLNetLMHeadModel",
@@ -619,8 +619,18 @@ class TextGenerationPipeline(Pipeline):
         # Manage correct placement of the tensors
         with self.device_placement():
             if self.model.__class__.__name__ in ["XLNetLMHeadModel", "TransfoXLLMHeadModel"]:
-                inputs = self._parse_and_tokenize(self.PADDING_TEXT + prompt_text, padding=False, add_special_tokens=False)
+                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
+                padding_text = self.PADDING_TEXT + self.tokenizer.eos_token
+                padding = self._parse_and_tokenize(padding_text, padding=False, add_special_tokens=False)
+                # This impacts max_length and min_length argument that need adjusting.
+                padding_length = padding["input_ids"].shape[-1]
+                if "max_length" in generate_kwargs and generate_kwargs["max_length"] is not None:
+                    generate_kwargs["max_length"] += padding_length
+                if "min_length" in generate_kwargs and generate_kwargs["min_length"] is not None:
+                    generate_kwargs["min_length"] += padding_length
+
+                inputs = self._parse_and_tokenize(padding_text + prompt_text, padding=False, add_special_tokens=False)
             else:
                 inputs = self._parse_and_tokenize(prompt_text, padding=False, add_special_tokens=False)
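The core of the fix is the `max_length`/`min_length` adjustment: because XLNet and TransfoXL prepend a hidden `PADDING_TEXT` to the prompt, user-supplied length limits must be shifted by the padding's token count so they apply to the text the user actually sees. A minimal standalone sketch of that logic (the function name and structure here are illustrative, not the transformers API):

```python
def adjust_generate_kwargs(generate_kwargs, padding_length):
    """Shift user-facing max_length/min_length by the hidden padding length.

    `padding_length` stands in for the number of tokens in the tokenized
    PADDING_TEXT plus the EOS token (padding["input_ids"].shape[-1] in the
    patched pipeline code).
    """
    adjusted = dict(generate_kwargs)
    for key in ("max_length", "min_length"):
        # Mirror the patch: only adjust keys that are present and not None.
        if key in adjusted and adjusted[key] is not None:
            adjusted[key] += padding_length
    return adjusted


# Example: a user asks for at most 20 tokens; with a 170-token hidden
# padding text, the model must be allowed 190 tokens internally.
print(adjust_generate_kwargs({"max_length": 20, "min_length": None}, 170))
# → {'max_length': 190, 'min_length': None}
```

Before this commit, the unadjusted limits counted the hidden padding tokens, so a small `max_length` could be consumed entirely by the padding text and produce no visible generation.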