Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cf08830c
Unverified
Commit
cf08830c
authored
May 08, 2020
by
Patrick von Platen
Committed by
GitHub
May 08, 2020
Browse files
[Pipeline, Generation] tf generation pipeline bug (#4217)
* fix PR * move tests to correct place
parent
8bf73126
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
1 deletion
+38
-1
src/transformers/pipelines.py
src/transformers/pipelines.py
+23
-1
tests/test_pipelines.py
tests/test_pipelines.py
+15
-0
No files found.
src/transformers/pipelines.py
View file @
cf08830c
...
@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
...
@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT
=
"""In 1991, the remains of Russian Tsar Nicholas II and his family
PADDING_TEXT
=
"""In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
...
@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
...
@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
ALLOWED_MODELS
=
[
"XLNetLMHeadModel"
,
"TransfoXLLMHeadModel"
,
"ReformerModelWithLMHead"
,
"GPT2LMHeadModel"
,
"OpenAIGPTLMHeadModel"
,
"CTRLLMHeadModel"
,
"TFXLNetLMHeadModel"
,
"TFTransfoXLLMHeadModel"
,
"TFGPT2LMHeadModel"
,
"TFOpenAIGPTLMHeadModel"
,
"TFCTRLLMHeadModel"
,
]
def
__call__
(
def
__call__
(
self
,
*
args
,
return_tensors
=
False
,
return_text
=
True
,
clean_up_tokenization_spaces
=
False
,
**
generate_kwargs
self
,
*
args
,
return_tensors
=
False
,
return_text
=
True
,
clean_up_tokenization_spaces
=
False
,
**
generate_kwargs
):
):
if
self
.
model
.
__class__
.
__name__
not
in
self
.
ALLOWED_MODELS
:
raise
NotImplementedError
(
"Generation is currently not supported for {}. Please select a model from {} for generation."
.
format
(
self
.
model
.
__class__
.
__name__
,
self
.
ALLOWED_MODELS
)
)
text_inputs
=
self
.
_args_parser
(
*
args
)
text_inputs
=
self
.
_args_parser
(
*
args
)
results
=
[]
results
=
[]
...
@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
...
@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result
=
[]
result
=
[]
for
generated_sequence
in
output_sequences
:
for
generated_sequence
in
output_sequences
:
generated_sequence
=
generated_sequence
.
tolist
()
generated_sequence
=
generated_sequence
.
numpy
().
tolist
()
record
=
{}
record
=
{}
if
return_tensors
:
if
return_tensors
:
record
[
"generated_token_ids"
]
=
generated_sequence
record
[
"generated_token_ids"
]
=
generated_sequence
...
...
tests/test_pipelines.py
View file @
cf08830c
...
@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
...
@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
(
"xlnet-base-cased"
,
"xlnet-base-cased"
),
(
"xlnet-base-cased"
,
"xlnet-base-cased"
),
}
}
TF_TEXT_GENERATION_FINETUNED_MODELS
=
{
(
"gpt2"
,
"gpt2"
),
(
"xlnet-base-cased"
,
"xlnet-base-cased"
),
}
FILL_MASK_FINETUNED_MODELS
=
[
FILL_MASK_FINETUNED_MODELS
=
[
((
"distilroberta-base"
,
{
"use_fast"
:
False
}),
"distilroberta-base"
,
None
),
((
"distilroberta-base"
,
{
"use_fast"
:
False
}),
"distilroberta-base"
,
None
),
]
]
...
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
...
@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp
,
valid_inputs
,
invalid_inputs
,
{},
nlp
,
valid_inputs
,
invalid_inputs
,
{},
)
)
@
require_tf
def
test_tf_text_generation
(
self
):
valid_inputs
=
[
"A string like this"
,
[
"list of strings entry 1"
,
"list of strings v2"
]]
invalid_inputs
=
[
None
]
for
model
,
tokenizer
in
TF_TEXT_GENERATION_FINETUNED_MODELS
:
nlp
=
pipeline
(
task
=
"text-generation"
,
model
=
model
,
tokenizer
=
tokenizer
,
framework
=
"tf"
)
self
.
_test_mono_column_pipeline
(
nlp
,
valid_inputs
,
invalid_inputs
,
{},
)
class
MultiColumnInputTestCase
(
unittest
.
TestCase
):
class
MultiColumnInputTestCase
(
unittest
.
TestCase
):
def
_test_multicolumn_pipeline
(
self
,
nlp
,
valid_inputs
:
list
,
invalid_inputs
:
list
,
output_keys
:
Iterable
[
str
]):
def
_test_multicolumn_pipeline
(
self
,
nlp
,
valid_inputs
:
list
,
invalid_inputs
:
list
,
output_keys
:
Iterable
[
str
]):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment