Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cf08830c
Unverified
Commit
cf08830c
authored
May 08, 2020
by
Patrick von Platen
Committed by
GitHub
May 08, 2020
Browse files
[Pipeline, Generation] tf generation pipeline bug (#4217)
* fix PR * move tests to correct place
parent
8bf73126
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
1 deletion
+38
-1
src/transformers/pipelines.py
src/transformers/pipelines.py
+23
-1
tests/test_pipelines.py
tests/test_pipelines.py
+15
-0
No files found.
src/transformers/pipelines.py
View file @
cf08830c
...
...
@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT
=
"""In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
...
...
@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
# Model classes that support generation via this pipeline.  ``__call__``
# raises ``NotImplementedError`` for any model whose class name is not
# listed here; the list covers both the PyTorch heads and their
# ``TF``-prefixed TensorFlow counterparts.
ALLOWED_MODELS = [
    "XLNetLMHeadModel",
    "TransfoXLLMHeadModel",
    "ReformerModelWithLMHead",
    "GPT2LMHeadModel",
    "OpenAIGPTLMHeadModel",
    "CTRLLMHeadModel",
    "TFXLNetLMHeadModel",
    "TFTransfoXLLMHeadModel",
    "TFGPT2LMHeadModel",
    "TFOpenAIGPTLMHeadModel",
    "TFCTRLLMHeadModel",
]
def
__call__
(
self
,
*
args
,
return_tensors
=
False
,
return_text
=
True
,
clean_up_tokenization_spaces
=
False
,
**
generate_kwargs
):
if
self
.
model
.
__class__
.
__name__
not
in
self
.
ALLOWED_MODELS
:
raise
NotImplementedError
(
"Generation is currently not supported for {}. Please select a model from {} for generation."
.
format
(
self
.
model
.
__class__
.
__name__
,
self
.
ALLOWED_MODELS
)
)
text_inputs
=
self
.
_args_parser
(
*
args
)
results
=
[]
...
...
@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result
=
[]
for
generated_sequence
in
output_sequences
:
generated_sequence
=
generated_sequence
.
tolist
()
generated_sequence
=
generated_sequence
.
numpy
().
tolist
()
record
=
{}
if
return_tensors
:
record
[
"generated_token_ids"
]
=
generated_sequence
...
...
tests/test_pipelines.py
View file @
cf08830c
...
...
@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
(
"xlnet-base-cased"
,
"xlnet-base-cased"
),
}
# (model, tokenizer) checkpoint pairs exercised by the TensorFlow
# text-generation pipeline test; each checkpoint doubles as its own tokenizer.
TF_TEXT_GENERATION_FINETUNED_MODELS = {
    (checkpoint, checkpoint) for checkpoint in ("gpt2", "xlnet-base-cased")
}
# Fill-mask test fixtures.  Each entry is a 3-tuple of
# ((tokenizer_name, tokenizer_kwargs), model_name, third_field) —
# NOTE(review): the consuming test is not visible here; the third field is
# presumably a config identifier (``None`` = default) — confirm against the
# fill-mask test that iterates this list.
FILL_MASK_FINETUNED_MODELS = [
    (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
...
...
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp
,
valid_inputs
,
invalid_inputs
,
{},
)
@require_tf
def test_tf_text_generation(self):
    """Smoke-test the text-generation pipeline under the TensorFlow framework.

    Builds a ``text-generation`` pipeline with ``framework="tf"`` for every
    (model, tokenizer) pair in ``TF_TEXT_GENERATION_FINETUNED_MODELS`` and
    runs the shared mono-column input checks on it.
    """
    ok_inputs = [
        "A string like this",
        ["list of strings entry 1", "list of strings v2"],
    ]
    bad_inputs = [None]
    for model_name, tokenizer_name in TF_TEXT_GENERATION_FINETUNED_MODELS:
        tf_pipe = pipeline(
            task="text-generation",
            model=model_name,
            tokenizer=tokenizer_name,
            framework="tf",
        )
        self._test_mono_column_pipeline(tf_pipe, ok_inputs, bad_inputs, {})
class
MultiColumnInputTestCase
(
unittest
.
TestCase
):
def
_test_multicolumn_pipeline
(
self
,
nlp
,
valid_inputs
:
list
,
invalid_inputs
:
list
,
output_keys
:
Iterable
[
str
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment