chenpangpang / transformers · Commits

Unverified commit 18058574, authored Apr 28, 2020 by Patrick von Platen, committed by GitHub on Apr 28, 2020
Parent: 52679fbc

[Generation] Generation should allow to start with empty prompt (#3993)

* fix empty prompt
* fix length in generation pipeline
Showing 2 changed files with 25 additions and 15 deletions:

* examples/run_generation.py (+6, -1)
* src/transformers/pipelines.py (+19, -14)
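The commit message above summarizes the intent: an empty prompt_text encodes to a zero-length tensor, which generation did not handle as a starting point; the fix passes input_ids=None in that case so generate() falls back to the model's default (BOS) token. Below is a minimal sketch of the behavior this enables, assuming a PyTorch GPT-2 checkpoint (the model and tokenizer names here are illustrative, not part of the commit):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# An empty prompt encodes to a tensor of shape (1, 0).
encoded_prompt = tokenizer.encode("", add_special_tokens=False, return_tensors="pt")

# The fix: pass None instead of a zero-length tensor, letting generate()
# start from the model's BOS token.
input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

output_sequences = model.generate(input_ids=input_ids, max_length=20, do_sample=True)
print(tokenizer.decode(output_sequences[0]))
```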
examples/run_generation.py

```diff
@@ -221,8 +221,13 @@ def main():
     encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)
 
+    if encoded_prompt.size()[-1] == 0:
+        input_ids = None
+    else:
+        input_ids = encoded_prompt
+
     output_sequences = model.generate(
-        input_ids=encoded_prompt,
+        input_ids=input_ids,
         max_length=args.length + len(encoded_prompt[0]),
         temperature=args.temperature,
         top_k=args.k,
```
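Note that the unchanged max_length=args.length + len(encoded_prompt[0]) argument already behaves sensibly for the empty case: a (1, 0) encoded_prompt contributes zero prompt tokens, so the length budget reduces to args.length alone. A quick illustrative check (the GPT-2 tokenizer is an assumption, not named by the commit):

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # illustrative checkpoint
empty = tokenizer.encode("", add_special_tokens=False, return_tensors="pt")
assert len(empty[0]) == 0        # an empty prompt contributes zero tokens
assert 20 + len(empty[0]) == 20  # so max_length degenerates to args.length
```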
src/transformers/pipelines.py

```diff
@@ -563,14 +563,19 @@ class TextGenerationPipeline(Pipeline):
             else:
                 inputs = self._parse_and_tokenize(prompt_text)
 
-            if self.framework == "pt":
+            # set input_ids to None to allow empty prompt
+            if inputs["input_ids"].shape[-1] == 0:
+                inputs["input_ids"] = None
+                inputs["attention_mask"] = None
+
+            if self.framework == "pt" and inputs["input_ids"] is not None:
                 inputs = self.ensure_tensor_on_device(**inputs)
 
             input_ids = inputs["input_ids"]
 
             # Ensure that batch size = 1 (batch generation not allowed for now)
             assert (
-                input_ids.shape[0] == 1
+                input_ids is None or input_ids.shape[0] == 1
             ), "Batch generation is currently not supported. See https://github.com/huggingface/transformers/issues/3021 for more information."
 
             output_sequences = self.model.generate(input_ids=input_ids, **generate_kwargs)  # BS x SL
```
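The reworked assertion relies on Python's short-circuiting `or`: when input_ids is None, the shape attribute is never evaluated, so the new empty-prompt path cannot trip an AttributeError while the batch-size guard still applies to real tensors. A minimal illustration:

```python
# `input_ids = None` stands in for the empty-prompt case above; the
# assert passes without ever evaluating `.shape` on None.
input_ids = None
assert input_ids is None or input_ids.shape[0] == 1
```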
```diff
@@ -590,18 +595,18 @@ class TextGenerationPipeline(Pipeline):
                 )
 
                 # Remove PADDING prompt of the sequence if XLNet or Transfo-XL model is used
-                record["generated_text"] = (
-                    prompt_text
-                    + text[
-                        len(
-                            self.tokenizer.decode(
-                                input_ids[0],
-                                skip_special_tokens=True,
-                                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                            )
-                        ) :
-                    ]
-                )
+                if input_ids is None:
+                    prompt_length = 0
+                else:
+                    prompt_length = len(
+                        self.tokenizer.decode(
+                            input_ids[0],
+                            skip_special_tokens=True,
+                            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                        )
+                    )
+
+                record["generated_text"] = prompt_text + text[prompt_length:]
 
                 result.append(record)
             results += [result]
```
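Taken together with the prompt_length change above, these edits make the empty string a valid pipeline input end to end: input_ids becomes None, generation starts from the BOS token, and the decoded prompt prefix that gets stripped is empty. A hedged usage sketch (the "gpt2" checkpoint is illustrative, not named by the commit):

```python
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")

# Empty prompt: input_ids is set to None internally, prompt_length is 0,
# and the generated text comes back with no prefix stripped.
print(generator("", max_length=20))

# Non-empty prompts take the unchanged code path.
print(generator("Hello, my name is", max_length=20))
```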