Unverified commit 1c122a46, authored by Penut Chen and committed by GitHub

Support dequantizing GGUF FP16 format (#31783)

* support gguf fp16

* support gguf bf16 with pytorch

* add gguf f16 test

* remove bf16
parent af0e4b7b
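
For context, a minimal usage sketch of what this change enables (not part of the diff itself; the repo and file names are the ones used in the new test below, and `gguf_file` is the existing GGUF loading argument in transformers):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
gguf_file = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"

# With this change, an FP16 GGUF checkpoint can be loaded into a regular
# transformers model the same way as the quantized Q4/Q8 variants.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)

inputs = tokenizer("Hello", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```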
@@ -36,6 +36,7 @@ logger = logging.get_logger(__name__)
 # Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
 GGML_TYPES = {
     "F32": 0,
+    "F16": 1,
     "Q4_0": 2,
     "Q8_0": 8,
     "Q2_K": 10,
@@ -489,6 +490,8 @@ def dequantize_q5_k(data):
 def load_dequant_gguf_tensor(shape, ggml_type, data):
     if ggml_type == GGML_TYPES["F32"]:
         values = data
+    elif ggml_type == GGML_TYPES["F16"]:
+        values = data
     elif ggml_type == GGML_TYPES["Q8_0"]:
         values = dequantize_q8_0(data)
     elif ggml_type == GGML_TYPES["Q4_0"]:
......
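
The F16 branch mirrors the F32 one: GGUF stores F16 tensors as raw half-precision values rather than quantized blocks, so there is nothing to dequantize and the data is passed through unchanged before the loader reshapes it. A minimal sketch of that idea (the helper name is hypothetical, and the reversed-shape reshape is an assumption about how the surrounding loader finishes, not code from this PR):

```python
import numpy as np

def load_f16_gguf_tensor(shape, data):
    # Hypothetical standalone helper mirroring the new F16 branch above.
    # GGUF stores F16 tensors as raw float16 values (no quantization blocks),
    # so no dequantization step is needed: the data is used as-is.
    values = np.asarray(data, dtype=np.float16)
    # Assumption: like the surrounding loader, the tensor is reshaped with the
    # dimensions reversed, since GGUF records shapes in the opposite order
    # from PyTorch/numpy.
    return values.reshape(shape[::-1])
```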
@@ -33,6 +33,7 @@ class GgufIntegrationTests(unittest.TestCase):
     mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
     qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
+    tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
     q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
@@ -45,6 +46,7 @@ class GgufIntegrationTests(unittest.TestCase):
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
     q4_llama3_model_id = "Meta-Llama-3-8B-Q4_K_M.gguf"
+    f16_tinyllama_model_id = "TinyLlama-1.1B-Chat-v1.0.FP16.gguf"
 
     example_text = "Hello"
@@ -149,6 +151,18 @@ class GgufIntegrationTests(unittest.TestCase):
         EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_f16(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id
+        ).to(torch_device)
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, World!\n\n5. Node.js"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_mistral_q4_0(self):
         tokenizer = AutoTokenizer.from_pretrained(self.mistral_model_id, gguf_file=self.q4_0_mistral_model_id)
         model = AutoModelForCausalLM.from_pretrained(
......