"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9ff672fc4d84db3b077e03ea22e2dafbd5d99fa4"
Unverified commit 8b46c5bc authored by Matt, committed by GitHub

Add add_generation_prompt argument to apply_chat_template (#26573)

* Add add_generation_prompt argument to apply_chat_template

* Add add_generation_prompt argument to apply_chat_template and update default templates

* Fix typo

* Add generation prompts section to chat templating guide

* Add generation prompts section to chat templating guide

* Minor style fix
parent 03af4c42
@@ -218,10 +218,11 @@ input formats. Our default template for models that don't have a class-specific
{% endfor %}
```
If you like this one, here it is in one-liner form, ready to copy into your code. The one-liner also includes
handy support for "generation prompts" - see the next section for more!
```
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
```
This template wraps each message in `<|im_start|>` and `<|im_end|>` tokens, and simply writes the role as a string, which
@@ -240,6 +241,56 @@ The "user", "system" and "assistant" roles are the standard for chat, and we rec
particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
to these roles - templating is extremely flexible, and any string can be a role.
## What are "generation prompts"?
You may notice that the `apply_chat_template` method has an `add_generation_prompt` argument. This argument tells
the template to add tokens that indicate the start of a bot response. For example, consider the following chat:
```python
messages = [
{"role": "user", "content": "Hi there!"},
{"role": "assistant", "content": "Nice to meet you!"},
{"role": "user", "content": "Can I ask a question?"}
]
```
Here's what this will look like without a generation prompt, using the ChatML template we described above:
```python
>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
"""<|im_start|>user
Hi there!<|im_end|>
<|im_start|>assistant
Nice to meet you!<|im_end|>
<|im_start|>user
Can I ask a question?<|im_end|>
"""
```
And here's what it looks like **with** a generation prompt:
```python
>> tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
"""<|im_start|>user
Hi there!<|im_end|>
<|im_start|>assistant
Nice to meet you!<|im_end|>
<|im_start|>user
Can I ask a question?<|im_end|>
<|im_start|>assistant
"""
```
Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're
supposed to be doing.
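
To make this concrete, here is a minimal sketch of feeding a templated prompt into `generate`, reusing the `messages` list from above. It mirrors what [`ConversationalPipeline`] does internally; the checkpoint name is a placeholder for any model whose tokenizer has a chat template, and details like `max_new_tokens` are illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "your-org/your-chat-model"  # placeholder: any checkpoint with a chat template
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Render the chat with the generation prompt so the model starts an assistant turn
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
input_ids = torch.LongTensor([input_ids])  # apply_chat_template returns a list of token ids by default

output_ids = model.generate(input_ids, max_new_tokens=128)

# Decode only the newly generated tokens, skipping the prompt we passed in
print(tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True))
```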
Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any
special tokens before bot responses. In these cases, the `add_generation_prompt` argument will have no effect. The exact
effect that `add_generation_prompt` has will depend on the template being used.
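
If you're not sure whether the template you're using actually responds to the flag, a quick check is to render the same chat both ways with `tokenize=False` and compare the strings:

```python
with_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
without_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

# If the two strings are identical, this template simply ignores add_generation_prompt
print(with_prompt == without_prompt)
```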
## I want to use chat templates! How should I get started?
If you have any chat models, you should set their `tokenizer.chat_template` attribute and test it using
...
@@ -181,7 +181,10 @@ class GPTNeoXJapaneseTokenizer(PreTrainedTokenizer):
        A simple chat template that just adds BOS/EOS tokens around messages while discarding role information.
        """
        return (
            "{% for message in messages %}"
            "{{ bos_token + eos_token + message.content + eos_token }}"
            "{% endfor %}"
            "{% if add_generation_prompt %} {{ bos_token + eos_token }} {% endif %}"
        )

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
...
@@ -262,7 +262,7 @@ class ConversationalPipeline(Pipeline):
        return outputs

    def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
        input_ids = self.tokenizer.apply_chat_template(conversation, add_generation_prompt=True)

        if self.framework == "pt":
            input_ids = torch.LongTensor([input_ids])
...
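
For context, here is roughly how this change surfaces to users of the pipeline - a sketch with a placeholder checkpoint name. Because `preprocess` now passes `add_generation_prompt=True`, the rendered prompt ends with the assistant-start tokens and the model replies as the assistant:

```python
from transformers import Conversation, pipeline

# Placeholder checkpoint: any chat model whose tokenizer has a chat_template set
chatbot = pipeline("conversational", model="your-org/your-chat-model")

conversation = Conversation("Can I ask a question?")
conversation = chatbot(conversation)

# The generation prompt was appended internally during preprocessing
print(conversation.generated_responses[-1])
```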
@@ -1718,6 +1718,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        self,
        conversation: Union[List[Dict[str, str]], "Conversation"],
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = False,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
@@ -1736,6 +1737,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                with "role" and "content" keys, representing the chat history so far.
            chat_template (str, *optional*): A Jinja template to use for this conversion. If
                this is not passed, the model's default chat template will be used instead.
            add_generation_prompt (bool, *optional*): Whether to end the prompt with the token(s) that indicate
                the start of an assistant message. This is useful when you want to generate a response from the model.
                Note that this argument will be passed to the chat template, and so it must be supported in the
                template for this argument to have any effect.
            tokenize (`bool`, defaults to `True`):
                Whether to tokenize the output. If `False`, the output will be a string.
            padding (`bool`, defaults to `False`):
@@ -1773,7 +1778,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        # Compilation function uses a cache to avoid recompiling the same template
        compiled_template = self._compile_jinja_template(chat_template)

        rendered = compiled_template.render(
            messages=conversation, add_generation_prompt=add_generation_prompt, **self.special_tokens_map
        )

        if padding is True:
            padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
@@ -1815,6 +1822,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            "{% for message in messages %}"
            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            "{{ '<|im_start|>assistant\n' }}"
            "{% endif %}"
        )

    @classmethod
...
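
To see what this template change does at render time, here is a standalone sketch that renders the ChatML-style default template with plain `jinja2`. The library compiles templates through its own cached helper rather than a bare `Template` object, so this is for illustration only; it shows that `add_generation_prompt` is just another variable passed into the template:

```python
from jinja2 import Template

# The default ChatML-style template from the diff above
chatml = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "{{ '<|im_start|>assistant\n' }}"
    "{% endif %}"
)

messages = [{"role": "user", "content": "Hi there!"}]

# add_generation_prompt is passed as an ordinary render variable
print(Template(chatml).render(messages=messages, add_generation_prompt=True))
# <|im_start|>user
# Hi there!<|im_end|>
# <|im_start|>assistant
```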