Unverified Commit 866df66f authored by Matt, committed by GitHub

Overhaul Conversation class and prompt templating (#25323)



* First commit while I figure this out

* make fixup

* Remove unused method

* Store prompt attrib

* Fix prompt argument for tests

* Make same changes in fast tokenizer

* Remove global prompts from fast tokenizer too

* stash commit

* stash commit

* Migrate PromptConfig to its True Final Location

* Replace Conversation entirely with the new class

* Import/dependency fixes

* Import/dependency fixes

* Change format for lots of default prompts

* More default prompt fixups

* Revert llama old methods so we can compare

* Fix some default configs

* Fix some default configs

* Fix misspelled kwarg

* Fixes for Blenderbot

* make fixup

* little rebase cleanup

* Add basic documentation

* Quick doc fix

* Truncate docstring for now

* Add handling for the case when messages is a single string

* Quick llama merges

* Update conversational pipeline and tests

* Add a couple of legacy properties for backward compatibility

* More legacy handling

* Add docstring for build_conversation_input_ids

* Restructure PromptConfig

* Let's start T E M P L A T I N G

* Refactor all default configs to use templates instead

* Revert changes to the special token properties since we don't need them anymore

* More class templates

* Make the sandbox even sandier

* Everything replaced with pure templating

* Remove docs for PromptConfig

* Add testing and optional requirement boilerplate

* Fix imports and make fixup

* Fix LLaMA tests and add Conversation docstring

* Finally get LLaMA working with the template system

* Finally get LLaMA working with the template system

* make fixup

* make fixup

* fmt-off for the long lists of test tokens

* Rename method to apply_chat_template for now

* Start on documentation

* Make chat_template a property that reads through to the default if it's not set

* Expand docs

* Expand chat templating doc some more

* trim/lstrip blocks by default and update doc

* Few doc tweaks

* rebase cleanup

* Clarify docstring

* rebase cleanup

* rebase cleanup

* make fixup

* Quick doc edit

* Reformat the standard template to match ChatML

* Re-add PEFT check

* Update docs/source/en/chat_templating.md
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Add apply_chat_template to the tokenizer doc

* make fixup

* Add doc links

* Fix chat links

* Fix chat links

* Explain system messages in the doc

* Add chat template test

* Proper save-loading for chat template attribute

* Add test skips for layout models

* Remove _build_conversation_input_ids, add default_chat_template to code_llama

* Make sure all LLaMA models are using the latest template

* Remove default_system_prompt block in code_llama because it has no default prompt

* Update ConversationPipeline preprocess

* Add correct #Copied from links to the default_chat_templates

* Remove unneeded type checking line

* Add a dummy mark_processed method

* Reorganize Conversation to have **deprecated_kwargs

* Update chat_templating.md

* Quick fix to LLAMA tests

* Small doc tweaks

* Add proper docstrings and "copied from" statements to all default chat templates

* Merge use_default_system_prompt support for code_llama too

* Improve clarity around self.chat_template

* Docstring fix

* Fix blenderbot default template

* More doctest fix

* Break out some tokenizer kwargs

* Update doc to explain default templates

* Quick tweaks to tokenizer args

* Cleanups for tokenizer args

* Add note about caching

* Quick tweak to the chat-templating doc

* Update the LLaMA template with error checking and correct system message embedding

* make fixup

* make fixup

* add requires_jinja

* Cleanup to expected output formatting

* Add caching

* Fix typo in llama default template

* Update LLaMA tests

* Update documentation

* Improved legacy handling in the Conversation class

* Update Jinja template with proper error handling

* Quick bugfix

* Proper exception raising

* Change caching behaviour so it doesn't try to pickle an entire Jinja env

* make fixup

* rebase cleanup

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7c63e6fc
 import uuid
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Union

 from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
 from .base import PIPELINE_INIT_ARGS, Pipeline
@@ -19,137 +19,153 @@ class Conversation:
     """
     Utility class containing a conversation and its history. This class is meant to be used as an input to the
     [`ConversationalPipeline`]. The conversation contains several utility functions to manage the addition of new user
-    inputs and generated model responses. A conversation needs to contain an unprocessed user input before being passed
-    to the [`ConversationalPipeline`]. This user input is either created when the class is instantiated, or by calling
-    `conversational_pipeline.append_response("input")` after a conversation turn.
+    inputs and generated model responses.

     Arguments:
-        text (`str`, *optional*):
-            The initial user input to start the conversation. If not provided, a user input needs to be provided
-            manually using the [`~Conversation.add_user_input`] method before the conversation can begin.
+        messages (Union[str, List[Dict[str, str]]], *optional*):
+            The initial messages to start the conversation, either a string, or a list of dicts containing "role" and
+            "content" keys. If a string is passed, it is interpreted as a single message with the "user" role.
         conversation_id (`uuid.UUID`, *optional*):
             Unique identifier for the conversation. If not provided, a random UUID4 id will be assigned to the
            conversation.
-        past_user_inputs (`List[str]`, *optional*):
-            Eventual past history of the conversation of the user. You don't need to pass it manually if you use the
-            pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
-            `generated_responses` with equal length lists of strings
-        generated_responses (`List[str]`, *optional*):
-            Eventual past history of the conversation of the model. You don't need to pass it manually if you use the
-            pipeline interactively but if you want to recreate history you need to set both `past_user_inputs` and
-            `generated_responses` with equal length lists of strings

     Usage:

     ```python
     conversation = Conversation("Going to the movies tonight - any suggestions?")
-
-    # Steps usually performed by the model when generating a response:
-    # 1. Mark the user input as processed (moved to the history)
-    conversation.mark_processed()
-    # 2. Append a mode response
-    conversation.append_response("The Big lebowski.")
-
-    conversation.add_user_input("Is it good?")
+    conversation.add_message({"role": "assistant", "content": "The Big lebowski."})
+    conversation.add_message({"role": "user", "content": "Is it good?"})
     ```"""

     def __init__(
-        self, text: str = None, conversation_id: uuid.UUID = None, past_user_inputs=None, generated_responses=None
+        self, messages: Union[str, List[Dict[str, str]]] = None, conversation_id: uuid.UUID = None, **deprecated_kwargs
     ):
         if not conversation_id:
             conversation_id = uuid.uuid4()
-        if past_user_inputs is None:
-            past_user_inputs = []
-        if generated_responses is None:
-            generated_responses = []

-        self.uuid: uuid.UUID = conversation_id
-        self.past_user_inputs: List[str] = past_user_inputs
-        self.generated_responses: List[str] = generated_responses
-        self.new_user_input: Optional[str] = text
+        if messages is None:
+            text = deprecated_kwargs.pop("text", None)
+            if text is not None:
+                messages = [{"role": "user", "content": text}]
+            else:
+                messages = []
+        elif isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+
+        # This block deals with the legacy args - new code should just totally
+        # avoid past_user_inputs and generated_responses
+        generated_responses = deprecated_kwargs.pop("generated_responses", None)
+        past_user_inputs = deprecated_kwargs.pop("past_user_inputs", None)
+        if generated_responses is not None and past_user_inputs is None:
+            raise ValueError("generated_responses cannot be passed without past_user_inputs!")
+        if past_user_inputs is not None:
+            legacy_messages = []
+            if generated_responses is None:
+                generated_responses = []
+            # We structure it this way instead of using zip() because the lengths may differ by 1
+            for i in range(max([len(past_user_inputs), len(generated_responses)])):
+                if i < len(past_user_inputs):
+                    legacy_messages.append({"role": "user", "content": past_user_inputs[i]})
+                if i < len(generated_responses):
+                    legacy_messages.append({"role": "assistant", "content": generated_responses[i]})
+            messages = legacy_messages + messages
+
+        self.uuid = conversation_id
+        self.messages = messages

     def __eq__(self, other):
         if not isinstance(other, Conversation):
             return False
-        if self.uuid == other.uuid:
-            return True
-        return (
-            self.new_user_input == other.new_user_input
-            and self.past_user_inputs == other.past_user_inputs
-            and self.generated_responses == other.generated_responses
-        )
+        return self.uuid == other.uuid or self.messages == other.messages
+
+    def add_message(self, message: Dict[str, str]):
+        if not set(message.keys()) == {"role", "content"}:
+            raise ValueError("Message should contain only 'role' and 'content' keys!")
+        if message["role"] not in ("user", "assistant", "system"):
+            raise ValueError("Only 'user', 'assistant' and 'system' roles are supported for now!")
+        self.messages.append(message)

     def add_user_input(self, text: str, overwrite: bool = False):
         """
-        Add a user input to the conversation for the next round. This populates the internal `new_user_input` field.
-
-        Args:
-            text (`str`): The user input for the next conversation round.
-            overwrite (`bool`, *optional*, defaults to `False`):
-                Whether or not existing and unprocessed user input should be overwritten when this function is called.
+        Add a user input to the conversation for the next round. This is a legacy method that assumes that inputs must
+        alternate user/assistant/user/assistant, and so will not add multiple user messages in succession. We recommend
+        just using `add_message` with role "user" instead.
         """
-        if self.new_user_input:
+        if len(self) > 0 and self[-1]["role"] == "user":
             if overwrite:
                 logger.warning(
-                    f'User input added while unprocessed input was existing: "{self.new_user_input}" was overwritten '
+                    f'User input added while unprocessed input was existing: "{self[-1]["content"]}" was overwritten '
                     f'with: "{text}".'
                 )
-                self.new_user_input = text
+                self[-1]["content"] = text
             else:
                 logger.warning(
-                    f'User input added while unprocessed input was existing: "{self.new_user_input}" new input '
+                    f'User input added while unprocessed input was existing: "{self[-1]["content"]}" new input '
                     f'ignored: "{text}". Set `overwrite` to True to overwrite unprocessed user input'
                 )
         else:
-            self.new_user_input = text
+            self.messages.append({"role": "user", "content": text})

-    def mark_processed(self):
+    def append_response(self, response: str):
         """
-        Mark the conversation as processed (moves the content of `new_user_input` to `past_user_inputs`) and empties
-        the `new_user_input` field.
+        This is a legacy method. We recommend just using `add_message` with an appropriate role instead.
         """
-        if self.new_user_input:
-            self.past_user_inputs.append(self.new_user_input)
-        self.new_user_input = None
+        self.messages.append({"role": "assistant", "content": response})

-    def append_response(self, response: str):
+    def mark_processed(self):
         """
-        Append a response to the list of generated responses.
-
-        Args:
-            response (`str`): The model generated response.
+        This is a legacy method that no longer has any effect, as the Conversation no longer distinguishes between
+        processed and unprocessed user input.
         """
-        self.generated_responses.append(response)
+        pass

-    def iter_texts(self):
-        """
-        Iterates over all blobs of the conversation.
-
-        Returns: Iterator of (is_user, text_chunk) in chronological order of the conversation. `is_user` is a `bool`,
-        `text_chunks` is a `str`.
-        """
-        for user_input, generated_response in zip(self.past_user_inputs, self.generated_responses):
-            yield True, user_input
-            yield False, generated_response
-        if self.new_user_input:
-            yield True, self.new_user_input
+    def __iter__(self):
+        for message in self.messages:
+            yield message
+
+    def __getitem__(self, item):
+        return self.messages[item]
+
+    def __setitem__(self, key, value):
+        self.messages[key] = value
+
+    def __len__(self):
+        return len(self.messages)

     def __repr__(self):
         """
         Generates a string representation of the conversation.

-        Return:
+        Returns:
             `str`:

-        Example: Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user >> Going to the movies tonight - any
-        suggestions? bot >> The Big Lebowski
+        Example:
+            Conversation id: 7d15686b-dc94-49f2-9c4b-c9eac6a1f114 user: Going to the movies tonight - any suggestions?
+            bot: The Big Lebowski
         """
-        output = f"Conversation id: {self.uuid} \n"
-        for is_user, text in self.iter_texts():
-            name = "user" if is_user else "bot"
-            output += f"{name} >> {text} \n"
+        output = f"Conversation id: {self.uuid}\n"
+        for message in self.messages:
+            output += f"{message['role']}: {message['content']}\n"
         return output

+    def iter_texts(self):
+        # This is a legacy method for backwards compatibility. It is recommended to just directly access
+        # conversation.messages instead.
+        for message in self.messages:
+            yield message["role"] == "user", message["content"]
+
+    @property
+    def past_user_inputs(self):
+        # This is a legacy property for backwards compatibility. It is recommended to just directly access
+        # conversation.messages instead.
+        return [message["content"] for message in self.messages if message["role"] == "user"]
+
+    @property
+    def generated_responses(self):
+        # This is a legacy property for backwards compatibility. It is recommended to just directly access
+        # conversation.messages instead.
+        return [message["content"] for message in self.messages if message["role"] == "assistant"]


 @add_end_docstrings(
     PIPELINE_INIT_ARGS,
@@ -246,18 +262,7 @@ class ConversationalPipeline(Pipeline):
         return outputs

     def preprocess(self, conversation: Conversation, min_length_for_response=32) -> Dict[str, Any]:
-        if not isinstance(conversation, Conversation):
-            raise ValueError("ConversationalPipeline, expects Conversation as inputs")
-        if conversation.new_user_input is None:
-            raise ValueError(
-                f"Conversation with UUID {type(conversation.uuid)} does not contain new user input to process. "
-                "Add user inputs with the conversation's `add_user_input` method"
-            )
-        if hasattr(self.tokenizer, "_build_conversation_input_ids"):
-            input_ids = self.tokenizer._build_conversation_input_ids(conversation)
-        else:
-            # If the tokenizer cannot handle conversations, we default to only the old version
-            input_ids = self._legacy_parse_and_tokenize(conversation)
+        input_ids = self.tokenizer.apply_chat_template(conversation)

         if self.framework == "pt":
             input_ids = torch.LongTensor([input_ids])
@@ -292,19 +297,5 @@ class ConversationalPipeline(Pipeline):
             clean_up_tokenization_spaces=clean_up_tokenization_spaces,
         )
         conversation = model_outputs["conversation"]
-        conversation.mark_processed()
-        conversation.append_response(answer)
+        conversation.add_message({"role": "assistant", "content": answer})
         return conversation
-
-    def _legacy_parse_and_tokenize(self, conversation: Conversation) -> Dict:
-        eos_token_id = self.tokenizer.eos_token_id
-        input_ids = []
-        for is_user, text in conversation.iter_texts():
-            if eos_token_id is not None:
-                input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False) + [eos_token_id])
-            else:
-                input_ids.extend(self.tokenizer.encode(text, add_special_tokens=False))
-
-        if len(input_ids) > self.tokenizer.model_max_length:
-            input_ids = input_ids[-self.tokenizer.model_max_length :]
-        return input_ids
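The net effect of the pipeline changes above is easiest to see end to end. A minimal usage sketch (not part of this diff; the checkpoint name is purely illustrative, any conversational checkpoint works):

```python
from transformers import Conversation, pipeline

# Build the chat as "role"/"content" dicts, or seed it from a plain string
conversation = Conversation("Going to the movies tonight - any suggestions?")
conversation.add_message({"role": "assistant", "content": "The Big Lebowski."})
conversation.add_message({"role": "user", "content": "Is it good?"})

# Illustrative checkpoint only; the pipeline now tokenizes via apply_chat_template
chatbot = pipeline("conversational", model="facebook/blenderbot-400M-distill")
result = chatbot(conversation)

# The pipeline appends the generated reply as an "assistant" message
print(result.messages[-1]["content"])
```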
@@ -64,6 +64,7 @@ from .utils import (
is_ftfy_available,
is_ipex_available,
is_jieba_available,
is_jinja_available,
is_jumanpp_available,
is_keras_nlp_available,
is_librosa_available,
@@ -336,6 +337,13 @@ def require_jieba(test_case):
return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
def require_jinja(test_case):
"""
Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
"""
return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
def require_tf2onnx(test_case):
return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
......
@@ -27,6 +27,7 @@ from collections import OrderedDict, UserDict
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
@@ -69,6 +70,7 @@ if TYPE_CHECKING:
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp  # noqa: F401
from .pipelines.conversational import Conversation
if is_tokenizers_available():
@@ -1426,6 +1428,7 @@ ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING = r"""
- **length** -- The length of the inputs (when `return_length=True`)
"""
INIT_TOKENIZER_DOCSTRING = r"""
Class attributes (overridden by derived classes)
@@ -1461,6 +1464,9 @@ INIT_TOKENIZER_DOCSTRING = r"""
truncation_side (`str`, *optional*):
The side on which the model should have truncation applied. Should be selected between ['right', 'left'].
Default value is picked from the class attribute of the same name.
chat_template (`str`, *optional*):
A Jinja template string that will be used to format lists of chat messages. See
https://huggingface.co/docs/transformers/chat_templating for a full description.
model_input_names (`List[string]`, *optional*):
The list of inputs accepted by the forward pass of the model (like `"token_type_ids"` or
`"attention_mask"`). Default value is picked from the class attribute of the same name.
@@ -1558,6 +1564,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
{}
)  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
self._in_target_context_manager = False
# Stores a Jinja template that formats chat histories into tokenizable strings
self.chat_template = kwargs.pop("chat_template", None)
super().__init__(**kwargs)
@property
@@ -1627,6 +1637,109 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"""
raise NotImplementedError()
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
chat_template: Optional[str] = None,
tokenize: bool = True,
padding: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
**tokenizer_kwargs,
) -> Union[str, List[int]]:
"""
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
to the default_chat_template specified at the class level.
Args:
conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
with "role" and "content" keys, representing the chat history so far.
chat_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the model's default chat template will be used instead.
tokenize (`bool`, defaults to `True`):
Whether to tokenize the output. If `False`, the output will be a string.
padding (`bool`, defaults to `False`):
Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
truncation (`bool`, defaults to `False`):
Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
max_length (`int`, *optional*):
Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
not specified, the tokenizer's `max_length` attribute will be used as a default.
return_tensors (`str` or [`~utils.TensorType`], *optional*):
If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
values are:
- `'tf'`: Return TensorFlow `tf.Tensor` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
Returns:
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`.
"""
if hasattr(conversation, "messages"):
# Indicates it's a Conversation object
conversation = conversation.messages
# priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
else:
chat_template = self.default_chat_template
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
if padding is True:
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
if tokenize:
return self.encode(
rendered,
add_special_tokens=False,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
else:
return rendered
@lru_cache
def _compile_jinja_template(self, chat_template):
try:
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")
def raise_exception(message):
raise TemplateError(message)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
@property
def default_chat_template(self):
"""
This template formats inputs in the standard ChatML format. See
https://github.com/openai/openai-python/blob/main/chatml.md
"""
return (
"{% for message in messages %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% endfor %}"
)
@classmethod
def from_pretrained(
cls,
@@ -2187,6 +2300,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if hasattr(self, k):
tokenizer_config[k] = getattr(self, k)
if self.chat_template is not None:
tokenizer_config["chat_template"] = self.chat_template
if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():
......
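A minimal sketch of calling the new `apply_chat_template` entry point (not part of this diff; the checkpoint is illustrative, and model-specific tokenizers may override `default_chat_template`, so the exact rendered string depends on the model):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
chat = [
    {"role": "system", "content": "You are a helpful chatbot."},
    {"role": "user", "content": "Hello!"},
]

# tokenize=False returns the rendered string instead of token ids
print(tokenizer.apply_chat_template(chat, tokenize=False))

# A custom Jinja template can be passed per call, or stored on the tokenizer,
# in which case save_pretrained persists it via tokenizer_config.json
tokenizer.chat_template = (
    "{% for message in messages %}{{ message['role'] + ': ' + message['content'] + '\n' }}{% endfor %}"
)
input_ids = tokenizer.apply_chat_template(chat)  # tokenize=True is the default
```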
@@ -119,6 +119,7 @@ from .import_utils import (
is_in_notebook,
is_ipex_available,
is_jieba_available,
is_jinja_available,
is_jumanpp_available,
is_kenlm_available,
is_keras_nlp_available,
......
@@ -91,6 +91,7 @@ except importlib.metadata.PackageNotFoundError:
_ftfy_available = _is_package_available("ftfy")
_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True)
_jieba_available = _is_package_available("jieba")
_jinja_available = _is_package_available("jinja2")
_kenlm_available = _is_package_available("kenlm")
_keras_nlp_available = _is_package_available("keras_nlp")
_librosa_available = _is_package_available("librosa")
@@ -793,6 +794,10 @@ def is_jieba_available():
return _jieba_available
def is_jinja_available():
return _jinja_available
# docstyle-ignore
DATASETS_IMPORT_ERROR = """
{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with:
@@ -1081,6 +1086,11 @@ PEFT_IMPORT_ERROR = """
peft`. Please note that you may need to restart your runtime after installation.
"""
JINJA_IMPORT_ERROR = """
{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install
jinja2`. Please note that you may need to restart your runtime after installation.
"""
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -1118,6 +1128,7 @@ BACKENDS_MAPPING = OrderedDict(
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)),
]
)
......
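For context, the new `("jinja", ...)` entry wires jinja2 into the existing optional-backend machinery, so downstream code can guard on it like any other soft dependency. A small sketch under that assumption (the `render_chat` helper is hypothetical, not part of this diff):

```python
from transformers.utils import is_jinja_available, requires_backends


def render_chat(template_string, messages):
    # requires_backends consults BACKENDS_MAPPING and raises an ImportError
    # carrying JINJA_IMPORT_ERROR when jinja2 is not installed
    requires_backends(render_chat, ["jinja"])
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
    return env.from_string(template_string).render(messages=messages)


if is_jinja_available():
    print(render_chat("{{ messages[0]['content'] }}", [{"role": "user", "content": "Hi!"}]))
```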
@@ -17,6 +17,7 @@
import unittest
from transformers import BlenderbotTokenizer, BlenderbotTokenizerFast
from transformers.testing_utils import require_jinja
from transformers.utils import cached_property
@@ -50,3 +51,24 @@ class Blenderbot3BTokenizerTests(unittest.TestCase):
def test_3B_tokenization_same_as_parlai_rust_tokenizer(self):
assert self.rust_tokenizer_3b.add_prefix_space
assert self.rust_tokenizer_3b([" Sam", "Sam"]).input_ids == [[5502, 2], [5502, 2]]
@require_jinja
def test_tokenization_for_chat(self):
tok = self.tokenizer_3b
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tok.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 2],
[553, 366, 265, 4792, 3879, 73, 311, 21, 228, 228, 6950, 8, 228, 3490, 287, 2273, 304, 21, 2],
[3490, 287, 2273, 304, 21, 228, 228, 6950, 8, 2],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@@ -18,7 +18,7 @@ import unittest
from datasets import load_dataset
from transformers import BloomTokenizerFast
from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -134,6 +134,27 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = self.get_rust_tokenizer()
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2],
[5448, 1306, 267, 66799, 44799, 37143, 17, 2, 59414, 4, 2, 229126, 427, 11890, 1152, 17, 2],
[229126, 427, 11890, 1152, 17, 2, 59414, 4, 2],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
def test_add_prefix_space_fast(self):
tokenizer_w_prefix = self.get_rust_tokenizer(add_prefix_space=True)
tokenizer_wo_prefix = self.get_rust_tokenizer(add_prefix_space=False)
......
@@ -20,7 +20,7 @@ import unittest
from transformers import AutoTokenizer, GPT2Tokenizer, GPT2TokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import VOCAB_FILES_NAMES
from transformers.testing_utils import require_jinja, require_tokenizers
from ...test_tokenization_common import TokenizerTesterMixin
@@ -275,6 +275,27 @@ class GPT2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname)
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20],
[20, 1, 20, 10, 20, 4, 3, 10, 20, 10, 20, 3, 0, 20, 20, 20, 0, 10, 20, 20, 20, 6, 20, 1, 6, 20, 20, 20, 3, 0, 0, 1, 20, 20, 20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20],
[20, 7, 20, 3, 10, 6, 1, 10, 20, 3, 3, 6, 10, 20, 1, 20, 20, 20, 20, 3, 0, 0, 1, 20, 20]]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@require_tokenizers
class OPTTokenizationTest(unittest.TestCase):
......
@@ -16,7 +16,7 @@
import unittest
from transformers import GPTSw3Tokenizer
from transformers.testing_utils import get_tests_dir, require_jinja, require_sentencepiece, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -128,3 +128,27 @@ class GPTSw3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
model_name="AI-Sweden/gpt-sw3-126m",
sequences=sequences,
)
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = GPTSw3Tokenizer(SAMPLE_VOCAB)
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ],
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 530, 339, 265, 878, 708, 727, 275, 347, 541, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 938, 541, 419, ],
[268, 63, 127, 462, 276, 294, 348, 536, 797, 275, 127, 65, 63, 263, 65, 938, 541, 419, 984, 429, 281, 264, 1261, 291, 260, 63, 263, 65, 1256, 263, 314, 419, 366, 354, 294, 360, 63, 263, 65, 938, 541, 419, ]
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@@ -22,7 +22,7 @@ from transformers.models.gptsan_japanese.tokenization_gptsan_japanese import (
VOCAB_FILES_NAMES,
GPTSanJapaneseTokenizer,
)
from transformers.testing_utils import require_jinja, require_tokenizers, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -193,3 +193,27 @@ class GPTSanJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_padding_different_model_input_name(self):
# tokenizer has no padding token
pass
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = self.tokenizer_class.from_pretrained("Tanrei/GPTSAN-japanese")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
[35993, 35998, 35637, 35659, 35665, 35716, 35645, 35662, 35649, 35716, 35645, 35716, 35652, 35649, 35656, 35660, 35650, 35665, 35656, 35716, 35647, 35652, 35645, 35664, 35646, 35659, 35664, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999, 35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999],
[35993, 35998, 35626, 35653, 35647, 35649, 35716, 35664, 35659, 35716, 35657, 35649, 35649, 35664, 35716, 35669, 35659, 35665, 35595, 35999, 35993, 35998, 35620, 35649, 35656, 35656, 35659, 35582, 35999],
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@@ -2486,3 +2486,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass
@@ -2439,3 +2439,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# This should not fail
model(encoded_sequence)
model(batch_encoded_sequence)
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass
@@ -1958,3 +1958,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't use SentencePiece")
def test_sentencepiece_tokenize_and_decode(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass
@@ -32,6 +32,7 @@ from transformers.convert_slow_tokenizer import convert_slow_tokenizer
from transformers.testing_utils import (
get_tests_dir,
nested_simplify,
require_jinja,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -574,6 +575,32 @@ class LlamaIntegrationTest(unittest.TestCase):
# a dummy prefix space is not added by the sp_model as it was de-activated
self.assertEqual(tokens, tokenizer.sp_model.encode("▁▁▁", out_type=str))
@require_jinja
def test_tokenization_for_chat(self):
tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b", legacy=False)
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "user", "content": "Hello!"}],
]
# Matt: The third test case tests the default system message, but if this is ever changed in the
# class/repo code then that test will fail, and the case will need to be updated.
tokenized_chats = [tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
# fmt: off
expected_tokens = [
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962],
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 13563, 7451, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962, 20103, 304, 5870, 366, 29889, 29871, 2],
[1, 29961, 25580, 29962, 3532, 14816, 29903, 6778, 13, 3492, 526, 263, 8444, 29892, 3390, 1319, 322, 15993, 20255, 29889, 29849, 1234, 408, 1371, 3730, 408, 1950, 29892, 1550, 1641, 9109, 29889, 3575, 6089, 881, 451, 3160, 738, 10311, 1319, 29892, 443, 621, 936, 29892, 11021, 391, 29892, 7916, 391, 29892, 304, 27375, 29892, 18215, 29892, 470, 27302, 2793, 29889, 3529, 9801, 393, 596, 20890, 526, 5374, 635, 443, 5365, 1463, 322, 6374, 297, 5469, 29889, 13, 13, 3644, 263, 1139, 947, 451, 1207, 738, 4060, 29892, 470, 338, 451, 2114, 1474, 16165, 261, 296, 29892, 5649, 2020, 2012, 310, 22862, 1554, 451, 1959, 29889, 960, 366, 1016, 29915, 29873, 1073, 278, 1234, 304, 263, 1139, 29892, 3113, 1016, 29915, 29873, 6232, 2089, 2472, 29889, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10994, 29991, 518, 29914, 25580, 29962]
]
# fmt: on
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@require_sentencepiece
@require_tokenizers
......
@@ -2311,3 +2311,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
"Dummy warning",
cm.records[0].message,
)
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass
@@ -1274,3 +1274,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip("Doesn't support another framework than PyTorch")
def test_np_encode_plus_sent_to_model(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template(self):
pass
@@ -16,7 +16,7 @@ import unittest
from transformers.models.whisper import WhisperTokenizer, WhisperTokenizerFast
from transformers.models.whisper.tokenization_whisper import _combine_tokens_into_words, _find_longest_common_sequence
from transformers.testing_utils import require_jinja, slow
from ...test_tokenization_common import TokenizerTesterMixin
@@ -473,3 +473,25 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
output = multilingual_tokenizer.decode(INPUT_TOKENS, output_offsets=True)["offsets"]
self.assertEqual(output, [])
@require_jinja
def test_tokenization_for_chat(self):
multilingual_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
# This is in English, but it's just here to make sure the chat control tokens are being added properly
test_chats = [
[{"role": "system", "content": "You are a helpful chatbot."}, {"role": "user", "content": "Hello!"}],
[
{"role": "system", "content": "You are a helpful chatbot."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Nice to meet you."},
],
[{"role": "assistant", "content": "Nice to meet you."}, {"role": "user", "content": "Hello!"}],
]
tokenized_chats = [multilingual_tokenizer.apply_chat_template(test_chat) for test_chat in test_chats]
expected_tokens = [
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257],
[3223, 366, 257, 4961, 5081, 18870, 13, 50257, 15947, 0, 50257, 37717, 220, 1353, 1677, 291, 13, 50257],
[37717, 220, 1353, 1677, 291, 13, 50257, 15947, 0, 50257],
]
for tokenized_chat, expected_tokens in zip(tokenized_chats, expected_tokens):
self.assertListEqual(tokenized_chat, expected_tokens)
@@ -78,17 +78,23 @@ class ConversationalPipelineTests(unittest.TestCase):
     def run_pipeline_test(self, conversation_agent, _):
         # Simple
         outputs = conversation_agent(Conversation("Hi there!"))
-        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+        self.assertEqual(
+            outputs,
+            Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+        )

         # Single list
         outputs = conversation_agent([Conversation("Hi there!")])
-        self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))
+        self.assertEqual(
+            outputs,
+            Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
+        )

         # Batch
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
-        self.assertEqual(len(conversation_1.past_user_inputs), 0)
-        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        self.assertEqual(len(conversation_1), 1)
+        self.assertEqual(len(conversation_2), 1)
         outputs = conversation_agent([conversation_1, conversation_2])
         self.assertEqual(outputs, [conversation_1, conversation_2])
@@ -96,32 +102,35 @@ class ConversationalPipelineTests(unittest.TestCase):
             outputs,
             [
                 Conversation(
-                    past_user_inputs=["Going to the movies tonight - any suggestions?"],
-                    generated_responses=[ANY(str)],
+                    [
+                        {"role": "user", "content": "Going to the movies tonight - any suggestions?"},
+                        {"role": "assistant", "content": ANY(str)},
+                    ],
+                ),
+                Conversation(
+                    [
+                        {"role": "user", "content": "What's the last book you have read?"},
+                        {"role": "assistant", "content": ANY(str)},
+                    ]
                 ),
-                Conversation(past_user_inputs=["What's the last book you have read?"], generated_responses=[ANY(str)]),
             ],
         )

         # One conversation with history
-        conversation_2.add_user_input("Why do you recommend it?")
+        conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
         outputs = conversation_agent(conversation_2)
         self.assertEqual(outputs, conversation_2)
         self.assertEqual(
             outputs,
             Conversation(
-                past_user_inputs=["What's the last book you have read?", "Why do you recommend it?"],
-                generated_responses=[ANY(str), ANY(str)],
+                [
+                    {"role": "user", "content": "What's the last book you have read?"},
+                    {"role": "assistant", "content": ANY(str)},
+                    {"role": "user", "content": "Why do you recommend it?"},
+                    {"role": "assistant", "content": ANY(str)},
+                ]
             ),
         )
-
-        with self.assertRaises(ValueError):
-            conversation_agent("Hi there!")
-        with self.assertRaises(ValueError):
-            conversation_agent(Conversation())
-        # Conversation have been consumed and are not valid anymore
-        # Inactive conversations passed to the pipeline raise a ValueError
-        with self.assertRaises(ValueError):
-            conversation_agent(conversation_2)

     @require_torch
     @slow
......
@@ -50,6 +50,7 @@ from transformers.testing_utils import (
check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
require_jinja,
require_tf,
require_tokenizers,
require_torch,
@@ -1052,6 +1053,40 @@ class TokenizerTesterMixin:
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
@require_jinja
def test_chat_template(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
expected_output = "systemsystem messageuseruser messageassistantassistant message"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
......