"tests/vscode:/vscode.git/clone" did not exist on "97c8199dbb4aa3229a84e5afa73fcfe456114555"
Unverified Commit cf08830c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Pipeline, Generation] tf generation pipeline bug (#4217)

* fix PR

* move tests to correct place
parent 8bf73126
...@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology # in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered. (except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
...@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline): ...@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>""" with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
# Class names of the models this pipeline accepts for generation.
# `__call__` compares `self.model.__class__.__name__` against this list and
# raises NotImplementedError for any model class that is not listed.
# NOTE(review): the "TF"-prefixed entries are presumably the TensorFlow
# counterparts of the entries above them -- confirm against the model classes.
ALLOWED_MODELS = [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"ReformerModelWithLMHead",
"GPT2LMHeadModel",
"OpenAIGPTLMHeadModel",
"CTRLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
"TFGPT2LMHeadModel",
"TFOpenAIGPTLMHeadModel",
"TFCTRLLMHeadModel",
]
def __call__( def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
): ):
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
raise NotImplementedError(
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
self.model.__class__.__name__, self.ALLOWED_MODELS
)
)
text_inputs = self._args_parser(*args) text_inputs = self._args_parser(*args)
results = [] results = []
...@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline): ...@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result = [] result = []
for generated_sequence in output_sequences: for generated_sequence in output_sequences:
generated_sequence = generated_sequence.tolist() generated_sequence = generated_sequence.numpy().tolist()
record = {} record = {}
if return_tensors: if return_tensors:
record["generated_token_ids"] = generated_sequence record["generated_token_ids"] = generated_sequence
......
...@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = { ...@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
("xlnet-base-cased", "xlnet-base-cased"), ("xlnet-base-cased", "xlnet-base-cased"),
} }
# (model, tokenizer) identifier pairs iterated by `test_tf_text_generation`,
# which builds a "text-generation" pipeline with framework="tf" for each pair.
TF_TEXT_GENERATION_FINETUNED_MODELS = {
("gpt2", "gpt2"),
("xlnet-base-cased", "xlnet-base-cased"),
}
FILL_MASK_FINETUNED_MODELS = [ FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None), (("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
] ]
...@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, {}, nlp, valid_inputs, invalid_inputs, {},
) )
@require_tf
def test_tf_text_generation(self):
    """Run the mono-column pipeline checks on every TF text-generation pair.

    For each (model, tokenizer) pair in TF_TEXT_GENERATION_FINETUNED_MODELS a
    "text-generation" pipeline is built with framework="tf" and handed to
    `_test_mono_column_pipeline` together with the valid and invalid inputs
    and an empty dict as the final argument.
    """
    good_inputs = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
    bad_inputs = [None]
    for model_name, tokenizer_name in TF_TEXT_GENERATION_FINETUNED_MODELS:
        generator = pipeline(
            task="text-generation",
            model=model_name,
            tokenizer=tokenizer_name,
            framework="tf",
        )
        self._test_mono_column_pipeline(generator, good_inputs, bad_inputs, {})
class MultiColumnInputTestCase(unittest.TestCase): class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]): def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment