"examples/vscode:/vscode.git/clone" did not exist on "3fe5c8e8a8fc3733c2be67dffb8a5b3935f94250"
Unverified Commit ac946aac authored by Penut Chen, committed by GitHub

Fix the incorrect permutation of gguf (#31788)



* Fix the incorrect permutation of gguf

* rename num_kv_heads
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* add typing to num_kv_heads
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* rename variables

* refactor permute function name

* update the expected text of the llama3 q4 test

---------
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 6fbea6d2
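
Note (not part of the commit): the patched code path runs when a GGUF checkpoint is loaded through transformers' gguf_file support, which dequantizes the tensors and converts them back to the Hugging Face layout. A minimal usage sketch, with a hypothetical repository and filename:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical repo and file; any llama-architecture GGUF checkpoint
    # exercises the attn_q/attn_k un-permutation fixed in this commit.
    model_id = "some-org/Meta-Llama-3-8B-GGUF"
    gguf_file = "model-Q4_K_M.gguf"

    tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
    model = AutoModelForCausalLM.from_pretrained(model_id, gguf_file=gguf_file)
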
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
import numpy as np
from tqdm import tqdm
@@ -147,10 +149,11 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
        if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
            num_heads = parsed_parameters["config"]["num_attention_heads"]
-            tmp_shape = (int(shape[-1] // num_heads // 2), num_heads, 2, shape[0])
-            weights = weights.reshape(tmp_shape)
-            weights = weights.transpose(0, 2, 1, 3)
-            weights = weights.reshape(shape[::-1])
+            num_kv_heads = parsed_parameters["config"]["num_key_value_heads"]
+            if ".attn_q." in name:
+                weights = reverse_permute_weights(weights, num_heads, num_heads)
+            elif ".attn_k." in name:
+                weights = reverse_permute_weights(weights, num_heads, num_kv_heads)

        for tensor_name in tensor_key_mapping:
            if tensor_name in name:
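
Note (not part of the commit): attn_q and attn_k need different head counts because grouped-query-attention checkpoints have fewer key/value heads than query heads. A hypothetical illustration with Llama-3-8B-like sizes:

    # Hypothetical GQA sizes (Llama-3-8B-like)
    num_heads, num_kv_heads, head_dim, hidden_size = 32, 8, 128, 4096
    attn_q_shape = (num_heads * head_dim, hidden_size)     # (4096, 4096) -> un-permute with num_heads
    attn_k_shape = (num_kv_heads * head_dim, hidden_size)  # (1024, 4096) -> un-permute with num_kv_heads
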
@@ -163,3 +166,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
        logger.info(f"Some keys of the GGUF file were not considered: {reader_keys}")

    return parsed_parameters
+
+
+def reverse_permute_weights(weights: np.ndarray, n_head: int, num_kv_heads: Optional[int] = None) -> np.ndarray:
+    # Original permutation implementation
+    # https://github.com/ggerganov/llama.cpp/blob/a38b884c6c4b0c256583acfaaabdf556c62fabea/convert_hf_to_gguf.py#L1402-L1408
+    if num_kv_heads is not None and n_head != num_kv_heads:
+        n_head = num_kv_heads
+
+    dim = weights.shape[0] // n_head // 2
+    w = weights.reshape(n_head, dim, 2, *weights.shape[1:])
+    return w.swapaxes(2, 1).reshape(weights.shape)
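
Note (not part of the commit): a quick self-contained round-trip check that the new helper inverts the export-time permutation. The permute function below paraphrases the llama.cpp code linked in the comment above and is an assumption, not part of this diff:

    import numpy as np

    def permute(w, n_head, n_head_kv=None):
        # Export-time (HF -> GGUF) permutation, paraphrased from the llama.cpp link above:
        # view rows as (head, 2, head_dim // 2), swap the inner axes, flatten back.
        if n_head_kv is not None and n_head != n_head_kv:
            n_head = n_head_kv
        return w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:]).swapaxes(1, 2).reshape(w.shape)

    def reverse_permute_weights(w, n_head, num_kv_heads=None):
        # Same logic as the helper added in this commit.
        if num_kv_heads is not None and n_head != num_kv_heads:
            n_head = num_kv_heads
        dim = w.shape[0] // n_head // 2
        return w.reshape(n_head, dim, 2, *w.shape[1:]).swapaxes(2, 1).reshape(w.shape)

    rng = np.random.default_rng(0)
    wq = rng.standard_normal((4096, 4096))  # attn_q-like: 32 heads * head_dim 128
    wk = rng.standard_normal((1024, 4096))  # attn_k-like: 8 KV heads * head_dim 128
    assert np.array_equal(reverse_permute_weights(permute(wq, 32, 32), 32, 32), wq)
    assert np.array_equal(reverse_permute_weights(permute(wk, 32, 8), 32, 8), wk)
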
@@ -188,8 +188,7 @@ class GgufIntegrationTests(unittest.TestCase):
        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

-        EXPECTED_TEXT = "Hello, I am new to this forum. I am"
+        EXPECTED_TEXT = "Hello, I am interested in [The Park]\nThe"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_tokenization_xnli(self):