"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c99fe0386be118bceaab1c85cdb8309eb8cb8208"
Unverified commit 9d999481, authored by Matt, committed by GitHub

Add correct batched handling for apply_chat_template (#29222)



* Add correct batched handling for apply_chat_template

* Fix warning method

* Add error for incompatible options

* expand tests

* Add a skip for markuplm

* Add skips for other layout models

* Skip for LayoutLMv2

* Slightly update the warning message

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* typo fix

* Update docstring for conversation kwarg

* Update return docstring

* Remove the warning, improve error message

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove return_dict=None

* Fix up some merge cruft

* More merge cruft

* Add another skip

* Add another skip

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3c17c529
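For context before the diff, here is a minimal sketch of the batched usage this commit enables. The checkpoint name is illustrative only (not part of this change); any tokenizer with a chat template behaves the same way:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # illustrative checkpoint

# A list of conversations (each a list of role/content dicts) is now handled
# as a batch rather than being mis-rendered as a single conversation.
conversations = [
    [{"role": "user", "content": "Hi there!"}],
    [{"role": "user", "content": "What is a tokenizer?"}],
]

# With tokenize=False this returns one rendered string per conversation; with
# tokenize=True it returns a list of token-id lists (or a BatchEncoding when
# return_dict=True).
texts = tok.apply_chat_template(conversations, tokenize=False)
assert isinstance(texts, list) and len(texts) == 2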
src/transformers/tokenization_utils_base.py:

@@ -1692,7 +1692,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
     def apply_chat_template(
         self,
-        conversation: Union[List[Dict[str, str]], "Conversation"],
+        conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"],
         chat_template: Optional[str] = None,
         add_generation_prompt: bool = False,
         tokenize: bool = True,
@@ -1703,15 +1703,15 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         return_dict: bool = False,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ) -> Union[str, List[int]]:
+    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
         """
-        Converts a Conversation object or a list of dictionaries with `"role"` and `"content"` keys to a list of token
+        Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
         ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to
         determine the format and control tokens to use when converting. When chat_template is None, it will fall back
         to the default_chat_template specified at the class level.

         Args:
-            conversation (Union[List[Dict[str, str]], "Conversation"]): A Conversation object or list of dicts
+            conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]], "Conversation"]): A list of dicts
                 with "role" and "content" keys, representing the chat history so far.
             chat_template (str, *optional*): A Jinja template to use for this conversion. If
                 this is not passed, the model's default chat template will be used instead.
@@ -1735,19 +1735,22 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 - `'pt'`: Return PyTorch `torch.Tensor` objects.
                 - `'np'`: Return NumPy `np.ndarray` objects.
                 - `'jax'`: Return JAX `jnp.ndarray` objects.
-            return_dict (`bool`, *optional*, defaults to `False`):
+            return_dict (`bool`, defaults to `False`):
                 Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
             tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
             **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.

         Returns:
-            `List[int]`: A list of token ids representing the tokenized chat so far, including control tokens. This
-            output is ready to pass to the model, either directly or via methods like `generate()`.
+            `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control
+            tokens. This output is ready to pass to the model, either directly or via methods like `generate()`. If
+            `return_dict` is set, will return a dict of tokenizer outputs instead.
         """

-        if hasattr(conversation, "messages"):
-            # Indicates it's a Conversation object
-            conversation = conversation.messages
+        if return_dict and not tokenize:
+            raise ValueError(
+                "`return_dict=True` is incompatible with `tokenize=False`, because there is no dict "
+                "of tokenizer outputs to return."
+            )

         if tokenizer_kwargs is None:
             tokenizer_kwargs = {}
@@ -1779,16 +1782,31 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = self._compile_jinja_template(chat_template)

+        if isinstance(conversation, (list, tuple)) and (
+            isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
+        ):
+            conversations = conversation
+            is_batched = True
+        else:
+            conversations = [conversation]
+            is_batched = False
+
+        rendered = []
         template_kwargs = {**self.special_tokens_map, **kwargs}  # kwargs overwrite special tokens if both are present
-        rendered = compiled_template.render(
-            messages=conversation, add_generation_prompt=add_generation_prompt, **template_kwargs
-        )
+        for chat in conversations:
+            if hasattr(chat, "messages"):
+                # Indicates it's a Conversation object
+                chat = chat.messages
+            rendered_chat = compiled_template.render(
+                messages=chat, add_generation_prompt=add_generation_prompt, **template_kwargs
+            )
+            rendered.append(rendered_chat)
+
+        if not is_batched:
+            rendered = rendered[0]

-        if padding is True:
-            padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
         if tokenize:
-            if return_dict:
-                return self(
-                    rendered,
-                    padding=padding,
-                    truncation=truncation,
+            out = self(
+                rendered,
+                padding=padding,
+                truncation=truncation,
@@ -1797,16 +1815,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 return_tensors=return_tensors,
                 **tokenizer_kwargs,
             )
+            if return_dict:
+                return out
             else:
-                return self.encode(
-                    rendered,
-                    padding=padding,
-                    truncation=truncation,
-                    max_length=max_length,
-                    add_special_tokens=False,
-                    return_tensors=return_tensors,
-                    **tokenizer_kwargs,
-                )
+                return out["input_ids"]
         else:
             return rendered
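The core behavioural changes above are the batch detection (a list of lists, or a list of Conversation objects, is treated as a batch) and the reworked return_dict path, which now funnels both cases through a single self(...) call. A minimal sketch of the resulting API, with a checkpoint name used purely for illustration:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # illustrative checkpoint
chat = [{"role": "user", "content": "Hello"}]

# return_dict=True returns the full tokenizer output (input_ids,
# attention_mask, ...) rather than only the input_ids list, and the ids
# match the return_dict=False result.
out = tok.apply_chat_template(chat, tokenize=True, return_dict=True)
ids = tok.apply_chat_template(chat, tokenize=True, return_dict=False)
assert out["input_ids"] == ids

# The new guard rejects the incoherent combination up front:
try:
    tok.apply_chat_template(chat, tokenize=False, return_dict=True)
except ValueError as err:
    print(err)  # `return_dict=True` is incompatible with `tokenize=False`, ...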

tests/models/layoutlmv2/test_tokenization_layoutlmv2.py:

@@ -195,6 +195,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
         )

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     def test_wordpiece_tokenizer(self):
         vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]

tests/models/layoutlmv3/test_tokenization_layoutlmv3.py:

@@ -140,6 +140,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         output_text = "lower newer"
         return input_text, output_text

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     def test_full_tokenizer(self):
         tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
         text = "lower newer"

tests/models/layoutxlm/test_tokenization_layoutxlm.py:

@@ -107,6 +107,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         output_text = "unwanted, running"
         return input_text, output_text

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     # override test in `test_tokenization_common.py` because of the required input format of the `__call__` method of
     # this tokenizer
     def test_save_sentencepiece_tokenizer(self) -> None:

tests/models/markuplm/test_tokenization_markuplm.py:

@@ -101,6 +101,10 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         return questions, nodes, xpaths

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     def get_input_output_texts(self, tokenizer):
         input_text = "UNwant\u00E9d,running"
         output_text = "unwanted, running"

tests/models/tapas/test_tokenization_tapas.py:

@@ -223,6 +223,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         rust_ids = rust_tokenizer.encode(sequence)
         self.assertListEqual(ids, rust_ids)

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()

tests/models/udop/test_tokenization_udop.py:

@@ -1153,6 +1153,14 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         # Assert there is online added_tokens special_tokens
         self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)

+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template(self):
+        pass
+
+    @unittest.skip("Chat template tests don't play well with table/layout models.")
+    def test_chat_template_batched(self):
+        pass
+
     @require_torch
     @slow
     def test_torch_encode_plus_sent_to_model(self):

tests/test_tokenization_common.py:

@@ -1104,26 +1104,73 @@ class TokenizerTesterMixin:
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
                 output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=False
+                    dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
                 )
                 self.assertEqual(output, expected_output)  # Test we can pass chat_template arg

                 # Check that no error raised when tokenize=True
-                tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template, tokenize=True)
+                output = tokenizer.apply_chat_template(
+                    dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
+                )
+                dict_output = tokenizer.apply_chat_template(
+                    dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
+                )
+                self.assertEqual(dict_output["input_ids"], output)  # Test return_dict behaviour matches

                 tokenizer.chat_template = dummy_template
                 self.assertEqual(tokenizer.chat_template, dummy_template)  # Test property setter
-                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
                 self.assertEqual(output, expected_output)  # Test chat_template attribute is used if no arg is passed
-                tokenizer.apply_chat_template(dummy_conversation, tokenize=True)  # Check that no error raised
+                # Check that no error raised
+                tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

                 with tempfile.TemporaryDirectory() as tmp_dir_name:
                     tokenizer.save_pretrained(tmp_dir_name)
                     tokenizer = tokenizer.from_pretrained(tmp_dir_name)

                 self.assertEqual(tokenizer.chat_template, dummy_template)  # Test template has persisted
-                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False)
+                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
                 self.assertEqual(output, expected_output)  # Test output is the same after reloading
-                tokenizer.apply_chat_template(dummy_conversation, tokenize=True)  # Check that no error raised
+                # Check that no error raised
+                tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)

+    @require_jinja
+    def test_chat_template_batched(self):
+        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
+        dummy_conversations = [
+            [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "assistant message"},
+            ],
+            [
+                {"role": "system", "content": "system message 2"},
+                {"role": "user", "content": "user message 2"},
+                {"role": "assistant", "content": "assistant message 2"},
+            ],
+        ]
+        tokenizers = self.get_tokenizers()
+        for tokenizer in tokenizers:
+            with self.subTest(f"{tokenizer.__class__.__name__}"):
+                output = tokenizer.apply_chat_template(
+                    dummy_conversations, chat_template=dummy_template, tokenize=False
+                )
+                self.assertEqual(
+                    output,
+                    [
+                        "systemsystem messageuseruser messageassistantassistant message",
+                        "systemsystem message 2useruser message 2assistantassistant message 2",
+                    ],
+                )
+                one_element_output = tokenizer.apply_chat_template(
+                    dummy_conversations[:1], chat_template=dummy_template, tokenize=False
+                )
+                self.assertEqual(
+                    one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
+                )  # Assert that list structure is retained even with one element
+                tokenizer.apply_chat_template(
+                    dummy_conversations, chat_template=dummy_template, tokenize=True
+                )  # Check that no error raised

     @require_jinja
     def test_chat_template_dict(self):
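To see why the test asserts those exact strings, the dummy template can be rendered directly with jinja2. This sketch is independent of the tokenizer machinery and assumes only that jinja2 is installed:

from jinja2 import Template

dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
conversation = [
    {"role": "system", "content": "system message"},
    {"role": "user", "content": "user message"},
    {"role": "assistant", "content": "assistant message"},
]

# Each message contributes role + content with no separator, producing the
# concatenated string expected by test_chat_template_batched.
print(Template(dummy_template).render(messages=conversation))
# -> systemsystem messageuseruser messageassistantassistant message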