Unverified Commit e4628434 authored by Isotr0py, committed by GitHub

Add Qwen2 GGUF loading support (#31175)

* add qwen2 gguf support

* Update docs

* fix qwen2 tokenizer

* add qwen2 gguf test

* fix typo in qwen2 gguf test

* format code

* Remove mistral, clarify the error message

* format code

* add typing and update docstring
parent df848acc
......@@ -63,6 +63,7 @@ For now the supported model architectures are the architectures that have been v
- LLaMa
- Mistral
- Qwen2
## Example usage
......
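The example-usage content of the docs hunk is truncated above. As a minimal sketch (repository and file names taken from the new test further down; generation settings are illustrative), loading a Qwen2 GGUF checkpoint looks like this:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
gguf_file = "qwen1_5-0_5b-chat-q4_0.gguf"

# Passing `gguf_file` makes Transformers dequantize the GGUF weights into a
# regular Qwen2 model and build a fast tokenizer from the file's metadata.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```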
......@@ -401,9 +401,11 @@ class HerbertConverter(Converter):
class Qwen2Converter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
merges = list(self.original_tokenizer.bpe_ranks.keys())
def converted(self, vocab: Dict[str, int] = None, merges: List[Tuple[str, str]] = None) -> Tokenizer:
if not vocab:
vocab = self.original_tokenizer.encoder
if not merges:
merges = list(self.original_tokenizer.bpe_ranks.keys())
tokenizer = Tokenizer(
BPE(
......
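The point of the new optional `vocab`/`merges` arguments is that the BPE backend can now be built from in-memory data, which is what a GGUF checkpoint provides, instead of reading a slow tokenizer's files. A stripped-down sketch of that idea with a toy vocabulary, using the `tokenizers` library directly (it omits the pre-tokenizer and byte-level handling of the real converter):

```python
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Toy in-memory vocab and ordered merge rules, standing in for the data a
# GGUF checkpoint would supply.
vocab = {"h": 0, "e": 1, "l": 2, "o": 3, "he": 4, "ll": 5, "hell": 6, "hello": 7}
merges = [("h", "e"), ("l", "l"), ("he", "ll"), ("hell", "o")]

tokenizer = Tokenizer(BPE(vocab, merges, fuse_unk=False))
print(tokenizer.encode("hello").tokens)  # ['hello'] once all merges have applied
```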
......@@ -25,7 +25,7 @@ from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE
from .. import AddedToken
from ..convert_slow_tokenizer import LlamaConverter
from ..convert_slow_tokenizer import LlamaConverter, Qwen2Converter
from ..utils import logging
from ..utils.logging import tqdm
......@@ -101,6 +101,21 @@ GGUF_TENSOR_MAPPING = {
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
"qwen2": {
"token_embd": "model.embed_tokens",
"blk": "model.layers",
"ffn_up": "mlp.up_proj",
"ffn_down": "mlp.down_proj",
"ffn_gate": "mlp.gate_proj",
"ffn_norm": "post_attention_layernorm",
"attn_norm": "input_layernorm",
"attn_q": "self_attn.q_proj",
"attn_v": "self_attn.v_proj",
"attn_k": "self_attn.k_proj",
"attn_output": "self_attn.o_proj",
"output.weight": "lm_head.weight",
"output_norm": "model.norm",
},
}
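For reference, a toy sketch of how a fragment mapping like the `qwen2` entry above can rename GGUF tensor names to Transformers parameter names (the helper and the truncated mapping are illustrative, not the library's actual loading code):

```python
# Subset of the qwen2 mapping above, enough for the example.
QWEN2_TENSOR_MAPPING = {
    "token_embd": "model.embed_tokens",
    "blk": "model.layers",
    "attn_q": "self_attn.q_proj",
    "attn_norm": "input_layernorm",
    "output_norm": "model.norm",
    "output.weight": "lm_head.weight",
}


def rename_gguf_tensor(name: str, mapping: dict) -> str:
    """Replace each mapped GGUF name fragment with its Transformers equivalent."""
    if name in mapping:  # exact matches such as "output.weight"
        return mapping[name]
    for ggml_key, hf_key in mapping.items():
        if ggml_key in name:
            name = name.replace(ggml_key, hf_key)
    return name


print(rename_gguf_tensor("blk.0.attn_q.weight", QWEN2_TENSOR_MAPPING))
# -> model.layers.0.self_attn.q_proj.weight
```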
......@@ -133,8 +148,19 @@ GGUF_CONFIG_MAPPING = {
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"qwen2": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
"tokenizer": {
"ggml.model": "model_type",
"ggml.bos_token_id": "bos_token_id",
"ggml.eos_token_id": "eos_token_id",
"ggml.unknown_token_id": "unk_token_id",
......@@ -490,14 +516,15 @@ class GGUFTokenizerSkeleton:
for k, v in dict_.items():
setattr(self, k, v)
if not hasattr(self, "tokens") or not hasattr(self, "scores"):
raise ValueError("tokens and scores need to be passed for a LLaMa tokenizer to be instantiated.")
else:
if not hasattr(self, "merges"):
if not hasattr(self, "tokens") or not hasattr(self, "scores"):
raise ValueError(
"tokens and scores need to be passed for a LLaMa tokenizer without merges to be instantiated."
)
tokens = self.tokens
scores = self.scores
vocab = {t: scores[i] for i, t in enumerate(tokens)}
if not hasattr(self, "merges"):
logger.warning("Merges were not in checkpoint, building merges on the fly.")
merges = []
for merge, piece_score in tqdm(vocab.items()):
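The warning above covers GGUF checkpoints that ship scored tokens but no merge rules; a simplified sketch of rebuilding plausible BPE merges from such a vocabulary (toy data and a condensed version of the loop that continues past the truncation, not the library's exact code):

```python
def build_merges_from_vocab(vocab: dict) -> list:
    """Rebuild plausible BPE merges from a {token: score} vocabulary.

    Every split of a token whose two halves are themselves tokens is a
    candidate merge; candidates are ordered by the merged token's score.
    """
    merges = []
    for token, score in vocab.items():
        for i in range(1, len(token)):
            left, right = token[:i], token[i:]
            if left in vocab and right in vocab:
                merges.append((left, right, score))
    merges.sort(key=lambda m: m[2], reverse=True)
    return [(left, right) for left, right, _ in merges]


toy_vocab = {"h": 0.0, "e": 0.0, "l": 0.0, "o": 0.0, "he": -1.0, "ll": -2.0, "hell": -3.0, "hello": -4.0}
print(build_merges_from_vocab(toy_vocab))
# [('h', 'e'), ('l', 'l'), ('he', 'll'), ('hell', 'o')]
```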
......@@ -562,16 +589,37 @@ class GGUFLlamaConverter(LlamaConverter):
return decoders.Sequence(sequence)
class GGUFQwen2Converter(Qwen2Converter):
def __init__(self, tokenizer_dict):
self.original_tokenizer = GGUFTokenizerSkeleton(tokenizer_dict)
def converted(self) -> Tokenizer:
vocab = {word: i for i, word in enumerate(self.original_tokenizer.tokens)}
merges = self.original_tokenizer.merges
tokenizer = super().converted(vocab, merges)
tokenizer.add_special_tokens(
[
AddedToken("<|endoftext|>", normalized=False, special=True),
AddedToken("<|im_start|>", normalized=False, special=True),
AddedToken("<|im_end|>", normalized=False, special=True),
]
)
return tokenizer
GGUF_TO_FAST_CONVERTERS = {
"llama": GGUFLlamaConverter,
"qwen2": GGUFQwen2Converter,
}
def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
def convert_gguf_tokenizer(architecture, tokenizer_dict) -> Tokenizer:
"""
Utilities to convert a GGUF tokenizer dictionary into a fast tokenizer instance.
Args:
architecture (`str`): The model architecture derived from the GGUF file.
tokenizer_dict (`dict`):
The tokenizer parameters extracted from the GGUF file, used to build the backend tokenizer for
[`~tokenization_utils_base.PreTrainedTokenizerFast`].
......@@ -580,6 +628,6 @@ def convert_gguf_tokenizer(tokenizer_dict) -> Tokenizer:
An instance of [`~tokenizers.Tokenizer`] to be used as the backend tokenizer of a
[`~tokenization_utils_base.PreTrainedTokenizerFast`]
"""
tokenizer_class_name = tokenizer_dict["tokenizer_type"]
tokenizer_class_name = architecture
converter_class = GGUF_TO_FAST_CONVERTERS[tokenizer_class_name]
return converter_class(tokenizer_dict).converted()
......@@ -118,8 +118,8 @@ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
)
super().__init__(
vocab_file,
merges_file,
vocab_file=vocab_file,
merges_file=merges_file,
tokenizer_file=tokenizer_file,
unk_token=unk_token,
bos_token=bos_token,
......
......@@ -118,8 +118,10 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
fast_tokenizer = convert_slow_tokenizer(slow_tokenizer)
elif gguf_file is not None:
# We need to convert a slow tokenizer to build the backend
tokenizer_dict = load_gguf_checkpoint(kwargs.get("vocab_file"))["tokenizer"]
fast_tokenizer = convert_gguf_tokenizer(tokenizer_dict)
gguf_param = load_gguf_checkpoint(kwargs.get("vocab_file"))
architecture = gguf_param["config"]["model_type"]
tokenizer_dict = gguf_param["tokenizer"]
fast_tokenizer = convert_gguf_tokenizer(architecture, tokenizer_dict)
elif self.slow_tokenizer_class is not None:
# We need to create and convert a slow tokenizer to build the backend
slow_tokenizer = self.slow_tokenizer_class(*args, **kwargs)
......
......@@ -31,6 +31,7 @@ class GgufIntegrationTests(unittest.TestCase):
original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
......@@ -41,6 +42,7 @@ class GgufIntegrationTests(unittest.TestCase):
q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
example_text = "Hello"
......@@ -157,6 +159,18 @@ class GgufIntegrationTests(unittest.TestCase):
EXPECTED_TEXT = "Hello,\n\nI'm trying to create a"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_qwen2_q4_0(self):
tokenizer = AutoTokenizer.from_pretrained(self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id)
model = AutoModelForCausalLM.from_pretrained(
self.qwen2_model_id, gguf_file=self.q4_0_qwen2_model_id, device_map="auto", torch_dtype=torch.float16
)
text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
out = model.generate(**text, max_new_tokens=10)
EXPECTED_TEXT = "Hello.jsoup\n\nI am a beginner"
self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
def test_tokenization_xnli(self):
import tqdm
from datasets import load_dataset
......