Unverified Commit 536ea2ac authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`LlamaSlowConverter`] Slow to Fast better support (#29797)

* fix

* fix test

* style

* nit

* rather rely on concert token to id

* fix quality

* Update src/transformers/convert_slow_tokenizer.py
parent e2036468
...@@ -1331,9 +1331,9 @@ class LlamaConverter(SpmConverter): ...@@ -1331,9 +1331,9 @@ class LlamaConverter(SpmConverter):
def vocab(self, proto): def vocab(self, proto):
vocab = [ vocab = [
("<unk>", 0.0), (self.original_tokenizer.convert_ids_to_tokens(0), 0.0),
("<s>", 0.0), (self.original_tokenizer.convert_ids_to_tokens(1), 0.0),
("</s>", 0.0), (self.original_tokenizer.convert_ids_to_tokens(2), 0.0),
] ]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
return vocab return vocab
...@@ -1371,9 +1371,9 @@ class LlamaConverter(SpmConverter): ...@@ -1371,9 +1371,9 @@ class LlamaConverter(SpmConverter):
) )
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
[ [
AddedToken("<unk>", normalized=False, special=True), AddedToken(self.original_tokenizer.convert_ids_to_tokens(0), normalized=False, special=True),
AddedToken("<s>", normalized=False, special=True), AddedToken(self.original_tokenizer.convert_ids_to_tokens(1), normalized=False, special=True),
AddedToken("</s>", normalized=False, special=True), AddedToken(self.original_tokenizer.convert_ids_to_tokens(2), normalized=False, special=True),
] ]
) )
else: else:
......
...@@ -22,6 +22,7 @@ import requests ...@@ -22,6 +22,7 @@ import requests
from transformers import ( from transformers import (
AutoProcessor, AutoProcessor,
AutoTokenizer,
LlavaConfig, LlavaConfig,
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
is_torch_available, is_torch_available,
...@@ -575,3 +576,29 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): ...@@ -575,3 +576,29 @@ class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase):
labels=input_ids, labels=input_ids,
).loss ).loss
loss.backward() loss.backward()
def test_tokenizer_integration(self):
slow_tokenizer = AutoTokenizer.from_pretrained("liuhaotian/llava-v1.6-34b", use_fast=False)
slow_tokenizer.add_tokens("<image>", True)
fast_tokenizer = AutoTokenizer.from_pretrained(
"liuhaotian/llava-v1.6-34b",
bos_token="<|startoftext|>",
eos_token="<|endoftext|>",
from_slow=True,
legacy=False,
)
fast_tokenizer.add_tokens("<image>", True)
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
# If the token is added as special, it's not normalized, and the only diff is the extra space after special tokens.
# https://github.com/huggingface/transformers/pull/28881 is the fix for this.
self.assertEqual(
slow_tokenizer.tokenize(prompt),
['<|im_start|>', 'system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', 'user', '\n', '<image>', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', 'ass', 'istant', '\n']
) # fmt: skip
self.assertEqual(
fast_tokenizer.tokenize(prompt),
['<|im_start|>', '▁system', '\n', 'Answer', '▁the', '▁questions', '.', '<|im_end|>', '<|im_start|>', '▁user', '\n', '<image>', '▁', '\n', 'What', '▁is', '▁shown', '▁in', '▁this', '▁image', '?', '<|im_end|>', '<|im_start|>', '▁assistant', '\n']
) # fmt: skip
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment