"include/vscode:/vscode.git/clone" did not exist on "522d8b2f6d36b753f209846609555b536ec83166"
Unverified Commit cf08830c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Pipeline, Generation] tf generation pipeline bug (#4217)

* fix PR

* move tests to correct place
parent 8bf73126
......@@ -570,6 +570,7 @@ class TextGenerationPipeline(Pipeline):
# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
# in https://github.com/rusiaaman/XLNet-gen#methodology
# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
......@@ -581,9 +582,30 @@ class TextGenerationPipeline(Pipeline):
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
# Class names of the models this pipeline accepts for generation. __call__
# compares self.model.__class__.__name__ against this list and raises
# NotImplementedError (listing these names) when the model is not in it.
ALLOWED_MODELS = [
"XLNetLMHeadModel",
"TransfoXLLMHeadModel",
"ReformerModelWithLMHead",
"GPT2LMHeadModel",
"OpenAIGPTLMHeadModel",
"CTRLLMHeadModel",
"TFXLNetLMHeadModel",
"TFTransfoXLLMHeadModel",
"TFGPT2LMHeadModel",
"TFOpenAIGPTLMHeadModel",
"TFCTRLLMHeadModel",
]
def __call__(
self, *args, return_tensors=False, return_text=True, clean_up_tokenization_spaces=False, **generate_kwargs
):
if self.model.__class__.__name__ not in self.ALLOWED_MODELS:
raise NotImplementedError(
"Generation is currently not supported for {}. Please select a model from {} for generation.".format(
self.model.__class__.__name__, self.ALLOWED_MODELS
)
)
text_inputs = self._args_parser(*args)
results = []
......@@ -614,7 +636,7 @@ class TextGenerationPipeline(Pipeline):
result = []
for generated_sequence in output_sequences:
generated_sequence = generated_sequence.tolist()
generated_sequence = generated_sequence.numpy().tolist()
record = {}
if return_tensors:
record["generated_token_ids"] = generated_sequence
......
......@@ -65,6 +65,11 @@ TEXT_GENERATION_FINETUNED_MODELS = {
("xlnet-base-cased", "xlnet-base-cased"),
}
# (model, tokenizer) identifier pairs; test_tf_text_generation unpacks each
# pair and passes it to pipeline(task="text-generation", ..., framework="tf").
TF_TEXT_GENERATION_FINETUNED_MODELS = {
("gpt2", "gpt2"),
("xlnet-base-cased", "xlnet-base-cased"),
}
# NOTE(review): each entry is a 3-tuple whose first element is itself a
# (name, kwargs) pair; presumably (tokenizer spec, model, config) -- confirm
# against the fill-mask test that consumes this list.
FILL_MASK_FINETUNED_MODELS = [
(("distilroberta-base", {"use_fast": False}), "distilroberta-base", None),
]
......@@ -380,6 +385,16 @@ class MonoColumnInputTestCase(unittest.TestCase):
nlp, valid_inputs, invalid_inputs, {},
)
@require_tf
def test_tf_text_generation(self):
    """Run the mono-column pipeline checks on every TF text-generation model pair.

    Builds a "text-generation" pipeline with framework="tf" for each
    (model, tokenizer) pair in TF_TEXT_GENERATION_FINETUNED_MODELS and hands
    it to _test_mono_column_pipeline with one single-string input, one
    list-of-strings input, a None invalid input, and an empty output-keys
    mapping.
    """
    accepted = ["A string like this", ["list of strings entry 1", "list of strings v2"]]
    rejected = [None]
    for model_name, tokenizer_name in TF_TEXT_GENERATION_FINETUNED_MODELS:
        generator = pipeline(
            task="text-generation",
            model=model_name,
            tokenizer=tokenizer_name,
            framework="tf",
        )
        self._test_mono_column_pipeline(generator, accepted, rejected, {})
class MultiColumnInputTestCase(unittest.TestCase):
def _test_multicolumn_pipeline(self, nlp, valid_inputs: list, invalid_inputs: list, output_keys: Iterable[str]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment