Unverified Commit 0e9d44d7 authored by Matt, committed by GitHub

Update docstrings for text generation pipeline (#30343)

* Update docstrings for text generation pipeline

* Fix docstring arg

* Update docstring to explain chat mode

* Fix doctests

* Fix doctests
parent 2d92db84
@@ -37,10 +37,11 @@ class Chat:
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
specified text prompt. When the underlying model is a conversational model, it can also accept one or more chats,
in which case the pipeline will operate in chat mode and will continue the chat(s) by adding its response(s).
Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.
Examples:
```python
>>> from transformers import pipeline
@@ -53,6 +54,15 @@ class TextGenerationPipeline(Pipeline):
>>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
```
```python
>>> from transformers import pipeline
>>> generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta")
>>> # Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string
>>> generator([{"role": "user", "content": "What is the capital of France? Answer in one word."}], do_sample=False, max_new_tokens=2)
[{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'Paris'}]}]
```
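The chat-mode return shape shown in the doctest above can be sketched in plain Python. This is an illustrative simplification only, not the pipeline's internals: the input chat is returned continued with the assistant's reply appended.

```python
# Illustrative sketch of the chat-mode return shape (not actual pipeline code).
chat = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
reply = {"role": "assistant", "content": "Paris"}  # stand-in for the model's generated reply

# The pipeline returns the chat continued with the response appended at the end.
outputs = [{"generated_text": chat + [reply]}]
```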
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
text generation parameters in [Text generation strategies](../generation_strategies) and [Text
@@ -62,8 +72,9 @@ class TextGenerationPipeline(Pipeline):
`"text-generation"`.
The models that this pipeline can use are models that have been trained with an autoregressive language modeling
objective. See the list of available [text completion models](https://huggingface.co/models?filter=text-generation)
and the list of [conversational models](https://huggingface.co/models?other=conversational)
on [huggingface.co/models](https://huggingface.co/models).
"""
# Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
@@ -194,8 +205,11 @@ class TextGenerationPipeline(Pipeline):
Complete the prompt(s) given as inputs.
Args:
text_inputs (`str`, `List[str]`, `List[Dict[str, str]]`, or `List[List[Dict[str, str]]]`):
One or several prompts (or one list of prompts) to complete. If strings or a list of strings are
passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
the model's chat template will be used to format them before passing them to the model.
return_tensors (`bool`, *optional*, defaults to `False`):
Whether or not to return the tensors of predictions (as token indices) in the outputs. If set to
`True`, the decoded text is not returned.
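The accepted `text_inputs` shapes could be distinguished roughly as follows. This is a hypothetical helper for illustration, not the pipeline's actual dispatch logic:

```python
def looks_like_chat(x):
    # Illustrative heuristic: a chat is a non-empty list of dicts that each
    # carry "role" and "content" keys.
    return (
        isinstance(x, list)
        and len(x) > 0
        and all(isinstance(m, dict) and {"role", "content"} <= m.keys() for m in x)
    )


def classify_input(text_inputs):
    """Return "prompt", "prompts", "chat", or "chats" (hypothetical helper)."""
    if isinstance(text_inputs, str):
        return "prompt"
    if looks_like_chat(text_inputs):
        return "chat"
    if isinstance(text_inputs, list) and all(isinstance(t, str) for t in text_inputs):
        return "prompts"
    if isinstance(text_inputs, list) and all(looks_like_chat(t) for t in text_inputs):
        return "chats"
    raise ValueError("Unsupported input shape")
```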
@@ -222,7 +236,7 @@ class TextGenerationPipeline(Pipeline):
corresponding to your framework [here](./model#generative-models)).
Return:
A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination
of both `generated_text` and `generated_token_ids`):
- **generated_text** (`str`, present when `return_text=True`) -- The generated text.
...
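The mutually exclusive return keys described in the last hunk can be sketched as follows. This is an assumed simplification for illustration, not the actual postprocessing code:

```python
def build_output(generated_text, token_ids, return_text=True, return_tensors=False):
    # Illustrative: a returned dict carries generated_token_ids when
    # return_tensors=True, otherwise generated_text -- never both keys.
    if return_tensors:
        return {"generated_token_ids": token_ids}
    if return_text:
        return {"generated_text": generated_text}
    return {}
```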