Unverified Commit 9d999481 authored by Matt, committed by GitHub

Add correct batched handling for apply_chat_template (#29222)



* Add correct batched handling for apply_chat_template

* Fix warning method

* Add error for incompatible options

* expand tests

* Add a skip for markuplm

* Add skips for other layout models

* Skip for LayoutLMv2

* Slightly update the warning message

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* typo fix

* Update docstring for conversation kwarg

* Update return docstring

* Remove the warning, improve error message

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove return_dict=None

* Fix up some merge cruft

* More merge cruft

* Add another skip

* Add another skip

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3c17c529
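The change is easiest to see from the caller's side. Previously, a list of conversations was passed wholesale to the template as a single `messages` value, so the template iterated over lists instead of message dicts. With this commit, batching is handled explicitly. A minimal sketch of the new calling convention; the checkpoint name is only an illustration, and any tokenizer with a chat template behaves the same way:

```python
from transformers import AutoTokenizer

# Illustrative checkpoint only; any chat-capable tokenizer works the same.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

conversations = [
    [
        {"role": "user", "content": "Hi there!"},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ],
    [
        {"role": "user", "content": "What is the capital of France?"},
    ],
]

# A list of conversations now renders to a list of strings...
texts = tokenizer.apply_chat_template(conversations, tokenize=False)
assert isinstance(texts, list) and all(isinstance(t, str) for t in texts)

# ...and tokenizing a batch returns one list of token ids per conversation.
ids = tokenizer.apply_chat_template(conversations, tokenize=True)
assert len(ids) == 2 and isinstance(ids[0], list)
```

A single conversation keeps its old return types (`str` or `List[int]`), so existing callers are unaffected.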
......@@ -1692,7 +1692,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]], "Conversation"],
conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
chat_template: Optional[str] = None,
add_generation_prompt: bool = False,
tokenize: bool = True,
......@@ -1703,15 +1703,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_dict: bool = False,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[str, List[int]]:
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
"""
Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
determine the format and control tokens to use when converting. When chat_template is None, it will fall back
to the default_chat_template specified at the class level.
Args:
conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
with "role" and "content" keys, representing the chat history so far.
chat_template (str, *optional*): A Jinja template to use for this conversion. If
this is not passed, the model's default chat template will be used instead.
......@@ -1735,19 +1735,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
return_dict (`bool`, *optional*, defaults to `False`):
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
tokenizer_kwargs (`Dict[str, Any]`, *optional*): Additional kwargs to pass to the tokenizer.
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
Returns:
`List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`.
`Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
set, will return a dict of tokenizer outputs instead.
"""
if hasattr(conversation, "messages"):
# Indicates it's a Conversation object
conversation = conversation.messages
if return_dict and not tokenize:
raise ValueError(
"`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
"of tokenizer outputs to return."
)
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
......@@ -1779,16 +1782,31 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
):
conversations = conversation
is_batched = True
else:
conversations = [conversation]
is_batched = False
rendered = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
rendered = compiled_template.render(
messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs
for chat in conversations:
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
rendered_chat = compiled_template.render(
messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
)
rendered.append(rendered_chat)
if not is_batched:
rendered = rendered[0]
if padding is True:
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
if tokenize:
if return_dict:
return self(
out = self(
rendered,
padding=padding,
truncation=truncation,
......@@ -1797,16 +1815,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_tensors=return_tensors,
**tokenizer_kwargs,
)
if return_dict:
return out
else:
return self.encode(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
return out["input_ids"]
else:
return rendered
......
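The heart of the fix is the detection block above: a single conversation is a list of message dicts, so a batch can be recognized by its first element being itself a list/tuple, or a `Conversation` object exposing `.messages`. Restated as a standalone helper for illustration (the function name is hypothetical):

```python
def looks_batched(conversation) -> bool:
    # Mirrors the check added to apply_chat_template: a batch is a list/tuple
    # whose first element is itself a list/tuple of message dicts, or a
    # Conversation object (identified by its `messages` attribute).
    return isinstance(conversation, (list, tuple)) and (
        isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
    )


single = [{"role": "user", "content": "hi"}]
assert not looks_batched(single)        # first element is a dict
assert looks_batched([single, single])  # first element is a list
```

Single inputs are then wrapped in a one-element list so both cases share the same render loop, and unwrapped again at the end, which is why the single-conversation return types are unchanged.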
......@@ -195,6 +195,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]
......
......@@ -140,6 +140,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "lower newer"
return input_text, output_text
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
......
......@@ -107,6 +107,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "unwanted, running"
return input_text, output_text
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
# override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of
# this tokenizer
def test_save_sentencepiece_tokenizer(self) -> None:
......
......@@ -101,6 +101,10 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return questions, nodes, xpaths
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"
......
......@@ -223,6 +223,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_chinese(self):
tokenizer = BasicTokenizer()
......
......@@ -1153,6 +1153,14 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Assert there are only added_tokens special_tokens
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template(self):
pass
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
@require_torch
@slow
def test_torch_encode_plus_sent_to_model(self):
......
......@@ -1104,26 +1104,73 @@ class TokenizerTesterMixin:
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
)
dict_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
)
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
tokenizer.apply_chat_template(dummy_conversation, tokenize=True) # Check that no error raised
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja
def test_chat_template_batched(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
],
[
{"role": "system", "content": "system message 2"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "assistant message 2"},
],
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=False
)
self.assertEqual(
output,
[
"systemsystem messageuseruser messageassistantassistant message",
"systemsystem message 2useruser message 2assistantassistant message 2",
],
)
one_element_output = tokenizer.apply_chat_template(
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
)
self.assertEqual(
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
) # Assert that list structure is retained even with one element
tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_chat_template_dict(self):
......
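The common tests above pin down two contracts worth calling out: `return_dict=True` wraps the same token ids in the full tokenizer output, and combining it with `tokenize=False` now raises instead of silently returning text. A short sketch of both, reusing the tests' dummy template and assuming `tokenizer` is any chat-capable tokenizer (e.g. loaded as in the first snippet):

```python
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
chat = [{"role": "user", "content": "hello"}]

ids = tokenizer.apply_chat_template(chat, chat_template=dummy_template, tokenize=True)
out = tokenizer.apply_chat_template(chat, chat_template=dummy_template, tokenize=True, return_dict=True)
assert out["input_ids"] == ids  # the dict wraps the same ids

try:
    tokenizer.apply_chat_template(chat, chat_template=dummy_template, tokenize=False, return_dict=True)
except ValueError:
    pass  # no dict of tokenizer outputs exists when tokenize=False
```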