Unverified Commit 74d0eb3f authored by Yoni Gottesman, committed by GitHub

Return assistant generated tokens mask in apply_chat_template (#30650)

return assistant generated tokens mask in apply_chat_template
......@@ -1697,6 +1697,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
return_assistant_tokens_mask: bool = False,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
......@@ -1747,6 +1748,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
return_dict (`bool`, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
return_assistant_tokens_mask (`bool`, defaults to `False`):
Whether to return a mask of the assistant-generated tokens. For tokens generated by the assistant,
the mask will contain 1. For user and system tokens, the mask will contain 0.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
**kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
Returns:
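For orientation, a minimal usage sketch of the new flag. The checkpoint name here is an illustrative assumption; the only real requirements are a fast tokenizer and a chat template that wraps assistant turns in `{% generation %}`:

from transformers import AutoTokenizer

# Hypothetical checkpoint: any fast tokenizer whose chat template wraps
# assistant turns in {% generation %} ... {% endgeneration %} will do.
tok = AutoTokenizer.from_pretrained("some-org/chat-model")

chat = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello, how can I help?"},
]
out = tok.apply_chat_template(
    chat,
    tokenize=True,
    return_dict=True,                  # required when requesting the mask
    return_assistant_tokens_mask=True,
)
# assistant_masks aligns with input_ids: 1 on assistant tokens, 0 elsewhere.
# A common use is keeping the training loss only on assistant tokens:
labels = [
    tid if m == 1 else -100
    for tid, m in zip(out["input_ids"], out["assistant_masks"])
]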
......@@ -1761,6 +1766,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"of tokenizer outputs to return."
)
if return_assistant_tokens_mask and not return_dict:
raise ValueError("`return_assistant_tokens_mask=True` is incompatible with `return_dict=False`")
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
......@@ -1813,6 +1821,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"then to ensure that this model continues working without issues."
)
if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
logger.warning_once(
"return_assistant_tokens_mask=True but the chat template does not contain the `{% generation %}` keyword; the returned mask will contain only zeros."
)
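# Note: the pattern above also matches Jinja whitespace-control variants of
# the tag, e.g. {%- generation %} and {% generation -%}, not just {% generation %}.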
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
......@@ -1847,11 +1860,23 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
rendered = []
all_generation_indices = []
template_kwargs = {**self.special_tokens_map, **kwargs} # kwargs overwrite special tokens if both are present
for chat in conversations:
if hasattr(chat, "messages"):
# Indicates it's a Conversation object
chat = chat.messages
if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices(
compiled_template=compiled_template,
messages=chat,
tools=tool_schemas,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
)
all_generation_indices.append(generation_indices)
else:
rendered_chat = compiled_template.render(
messages=chat,
tools=tool_schemas,
......@@ -1875,17 +1900,54 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
**tokenizer_kwargs,
)
if return_dict:
if return_assistant_tokens_mask:
assistant_masks = []
if is_batched or return_tensors:
input_ids = out["input_ids"]
else:
input_ids = [out["input_ids"]]
for i in range(len(input_ids)):
current_mask = [0] * len(input_ids[i])
for assistant_start_char, assistant_end_char in all_generation_indices[i]:
start_token = out.char_to_token(i, assistant_start_char)
end_token = out.char_to_token(i, assistant_end_char - 1)
if start_token is None:
# start_token is out of bounds, possibly due to truncation
break
# end_token is None when the span end was truncated away; in that case, mask through the end of the sequence
for token_id in range(start_token, end_token + 1 if end_token is not None else len(input_ids[i])):
current_mask[token_id] = 1
assistant_masks.append(current_mask)
out["assistant_masks"] = assistant_masks if is_batched else assistant_masks[0]
return out
else:
return out["input_ids"]
else:
return rendered
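The mask construction in `return_dict` mode hinges on `char_to_token`, which fast tokenizers expose to map a character offset in the rendered chat string to a token index. A minimal standalone sketch, assuming any fast tokenizer (the checkpoint name is illustrative):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative; must be a fast tokenizer
text = "user: hi assistant: hello"
enc = tok(text)

span_start = text.index("assistant:")            # character offset where a span begins
start_token = enc.char_to_token(0, span_start)
end_token = enc.char_to_token(0, len(text) - 1)
# char_to_token returns None when the character falls outside the encoding
# (e.g. dropped by truncation), which is why the loop above guards against None.
print(start_token, end_token)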
def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices
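`_render_with_assistant_indices` leans on `Template.generate`, which, unlike `render`, yields the output as a stream of chunks; because each chunk is appended to `rendered_blocks` before the next one is produced, the extension below can compute character offsets from the text rendered so far. A standalone Jinja2 sketch of that streaming behavior:

from jinja2 import Environment

template = Environment().from_string("{% for m in messages %}[{{ m }}]{% endfor %}")
offset = 0
for block in template.generate(messages=["a", "b"]):
    # Chunks arrive incrementally; a running offset over the joined chunks
    # gives character positions in the final rendered string.
    print(offset, repr(block))
    offset += len(block)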
@lru_cache
def _compile_jinja_template(self, chat_template):
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")
......@@ -1903,7 +1965,49 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}
def __init__(self, environment: ImmutableSandboxedEnvironment):
# This class is only instantiated by Jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv
def is_active(self) -> bool:
return self._rendered_blocks is not None or self._generation_indices is not None
@contextmanager
def activate_tracker(self, rendered_blocks: list[str], generation_indices: list[tuple[int, int]]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices
yield
finally:
self._rendered_blocks = None
self._generation_indices = None
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)
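Putting the pieces together, a small end-to-end illustration that drives the tracker through the private helpers (shown for illustration only; these are internals, and the template and checkpoint are made-up examples):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
template = (
    "{{ 'Q: ' + messages[0]['content'] + ' A: ' }}"
    "{% generation %}{{ messages[1]['content'] }}{% endgeneration %}"
)
compiled = tok._compile_jinja_template(template)
rendered, spans = tok._render_with_assistant_indices(
    compiled_template=compiled,
    messages=[{"content": "2+2?"}, {"content": "4"}],
    tools=None,
    documents=None,
    add_generation_prompt=False,
)
print(rendered)  # Q: 2+2? A: 4
print(spans)     # [(11, 12)] -- character span of the generation block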
......
......@@ -2483,3 +2483,7 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
......@@ -2436,3 +2436,7 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
......@@ -1977,3 +1977,7 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
......@@ -2316,3 +2316,7 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
def test_added_tokens_serialization(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
......@@ -1277,3 +1277,7 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@unittest.skip(reason="Chat is not supported")
def test_chat_template(self):
pass
@unittest.skip("Chat is not supported")
def test_chat_template_return_assistant_tokens_mask(self):
pass
......@@ -1157,6 +1157,10 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_chat_template(self):
pass
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
def test_chat_template_return_assistant_tokens_mask(self):
pass
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
......
......@@ -1153,6 +1153,135 @@ class TokenizerTesterMixin:
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'] + '<|im_end|>'}}"
"{% endgeneration %}"
"{{'\n'}}"
"{% endif %}"
"{% endfor %}"
)
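# For reference, the first conversation below renders with this template to:
#   <|im_start|>system\nsystem message<|im_end|>\n
#   <|im_start|>user\nuser message<|im_end|>\n
#   <|im_start|>assistant\nstart turn 1 assistant message. end turn 1<|im_end|>\n
#   ... (the second user/assistant round follows the same pattern)
# The {% generation %} blocks cover exactly each assistant content plus the
# trailing <|im_end|>, which is what the prefix/suffix pairs below pick out.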
conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
],
[
{"role": "system", "content": "system message 3"},
{"role": "user", "content": "user message 3"},
{"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
{"role": "user", "content": "user message 4"},
{"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
],
]
# These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
# in the entire chat string, and then find the corresponding tokens in the tokenized output.
assistant_prefix_suffix = [
[("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
[("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
]
for tokenizer, pretrained_name, _ in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
if not self.test_rust_tokenizer:
self.skipTest(reason="No fast tokenizer defined")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
# check batched
output = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
assistant_end = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][0][1])
+ len(assistant_prefix_suffix[i][0][1])
- 1,
)
assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
assistant_end2 = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][1][1])
+ len(assistant_prefix_suffix[i][1][1])
- 1,
)
# assert 1 in first assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
# assert 1 in second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
# check not batched
output = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
assistant_end = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
)
assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
assistant_end2 = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
)
# assert 1 in assistant indices
self.assertEqual(
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
......