Unverified commit 84562228 authored by wang jiahao, committed by GitHub

Merge pull request #1276 from kvcache-ai/support_load_safetensor

Support safetensors loading; remove the architectures argument
parents 30eab48a c6aa379d
@@ -200,7 +200,7 @@ class ForwardBatchInput:
         device=None,
         tokens: torch.Tensor = None,
         num_mini_batches: int = 1,
-        max_seq_length: int = 1024, # TODO: add to yaml
+        max_seq_length: int = 4096, # TODO: add to yaml
         prefill_query_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size, # TODO: use config
         prefill_active_length: int = (Config().chunk_size - Config().max_decode_batch_size) // Config().max_prefill_batch_size,
         gen_prefill: bool = True,
@@ -223,12 +223,12 @@ class ForwardBatchInput:
         decode_querys_info = []
         for i in range(min(decode_batch_size, cuda_lens)):
-            query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, max_seq_length, page_size, device, is_prefill=False, offset=offset)
+            query_info = QueryInfo(i+Config().max_prefill_batch_size, prefill_query_length, 256, page_size, device, is_prefill=False, offset=offset)
             offset += max_seq_length // page_size
             if tokens is not None:
                 query_info.query_tokens[prefill_active_length:prefill_active_length + 1].copy_(tokens)
             if decode_active_position is None:
-                query_info.active_position = prefill_active_length
+                query_info.active_position = 255
             else:
                 query_info.active_position = decode_active_position[i]
...
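Note on the hunk above: decode-side queries now get a fixed 256-token window (active position 255), while the new max_seq_length default of 4096 still determines how many KV-cache pages each query reserves (offset += max_seq_length // page_size). A minimal sketch of that paging arithmetic, assuming a page_size of 64 purely for illustration (the real value comes from the runtime config):

    # Illustrative values only; max_seq_length matches the new default above,
    # page_size and the query count are assumptions.
    max_seq_length = 4096
    page_size = 64
    pages_per_query = max_seq_length // page_size        # 64 pages reserved per decode query
    offsets = [i * pages_per_query for i in range(4)]    # page offsets for 4 decode queries
    print(pages_per_query, offsets)                      # 64 [0, 64, 128, 192]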
@@ -39,6 +39,17 @@ def pad_num_tokens(num_tokens):
 def deduplicate_and_sort(lst):
     return sorted(set(lst))
+
+def generate_cuda_graphs(chunk_size: int) -> list:
+    # Assert if the input does not meet the requirements.
+    assert chunk_size <= 1024 or chunk_size % 1024 == 0, "chunk_size must be <= 1024 or a multiple of 1024"
+    base_list = [1, 2, 3, Config().max_batch_size, 64, 256, 512, chunk_size]
+    if chunk_size <= 1024:
+        return base_list
+    multiples = [i for i in range(1024, chunk_size + 1, 1024)]
+    return deduplicate_and_sort(base_list + multiples)
+
 class ModelRunner:
     """A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""
@@ -56,7 +67,7 @@ class ModelRunner:
         self.features_buf = None
         self.output = None
         self.graph_memory_pool = None
-        self.cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
+        self.cuda_graphs = generate_cuda_graphs(Config().chunk_size)
         self.use_cuda_graph = use_cuda_graph
         self.model_time = 0
         self.page_size = page_size
...
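To make the new CUDA-graph bucketing concrete, here is what generate_cuda_graphs would return assuming Config().max_batch_size is 4 and Config().chunk_size is 4096 (both values are configuration-dependent and only assumed here):

    # Stand-ins for Config().max_batch_size and Config().chunk_size.
    max_batch_size = 4
    chunk_size = 4096
    base_list = [1, 2, 3, max_batch_size, 64, 256, 512, chunk_size]
    multiples = list(range(1024, chunk_size + 1, 1024))   # [1024, 2048, 3072, 4096]
    cuda_graphs = sorted(set(base_list + multiples))
    print(cuda_graphs)  # [1, 2, 3, 4, 64, 256, 512, 1024, 2048, 3072, 4096]

This replaces the previous fixed bucket list, so graphs are now captured for every 1024-token multiple up to the configured chunk size.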
@@ -7,7 +7,7 @@ sys.path.append(current_path+"/../..")
 import numpy as np
 # from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
 # from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import GGUFLoader
 import torch
 import KTransformersOps
 torch.set_default_dtype(torch.bfloat16)
...
@@ -9,7 +9,7 @@ from pycuda.compiler import SourceModule
 import numpy as np
 from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
 from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
-from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
+from ktransformers.util.custom_loader import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
 import torch
 import KTransformersOps
 torch.set_default_dtype(torch.bfloat16)
...
@@ -159,5 +159,7 @@ if __name__ == "__main__":
         prompt = ktansformer_prompt1024
     elif args.prompt_lens == 2048:
         prompt = ktansformer_prompt1024 * 2
+    elif args.prompt_lens == 4096:
+        prompt = ktansformer_prompt1024 * 4
     asyncio.run(main(args.concurrent, prompt, max_tokens, model))
@@ -25,7 +25,6 @@ import os
 from enum import IntEnum
 import torch
 import KTransformersOps
-from .custom_loader import SafeTensorLoader
 import ctypes
 import math
@@ -166,238 +165,6 @@ DATA_TYPES = {
     "FP8": 13,
 }
-class GGUFLoader:
-    tensor_info: dict
-    gguf_path: str
-    tensor_file_map: dict # {tensor_name: tensor_file_path}
-    gguf_file_meta: dict
-    safetensor_loader: SafeTensorLoader
-
-    def __init__(self, gguf_path: str):
-        # Check dir exist
-        if not os.path.exists(gguf_path):
-            raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
-        if os.path.isfile(gguf_path):
-            gguf_path = os.path.dirname(gguf_path)
-
-        self.safetensor_loader = None
-        self.tensor_info = {}
-        self.gguf_path = gguf_path
-        self.tensor_file_map = {}
-        self.file_data_map = {}
-        self.gguf_file_meta = {}
-        self.tensor_device_map = {}
-
-        # I know this is ugly, but I don't want to change the original code too much
-        # TODO: merge gguf load and other loads.
-        safetensor_loader = SafeTensorLoader(gguf_path)
-        if safetensor_loader.tensor_file_map:
-            self.safetensor_loader = safetensor_loader
-            return
-
-        # Walk through all the .gguf files in the directory
-        found_gguf = False
-        for root, dirs, files in os.walk(gguf_path):
-            for file in files:
-                if file.endswith(".gguf"):
-                    found_gguf = True
-                    file_name = os.path.join(root, file)
-                    with open(file_name, "rb") as f:
-                        self.load_gguf(f)
-                        if file_name not in self.file_data_map:
-                            self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')
-        if not found_gguf:
-            raise FileNotFoundError(f"Cannot find any .gguf files in: {gguf_path}")
-
-    def load_gguf(self, f):
-        f.seek(0)
-        assert f.read(4) == b'GGUF'
-        values = struct.unpack("<IQQ", f.read(4+8+8))
-        version, n_tensors, n_kv = values
-        if version != 3:
-            warnings.warn(f"Version {version} has never been tested, might not work")
-
-        info = {}
-        for _ in range(n_kv):
-            name = read_value(f, DATA_TYPES["string"])
-            data_type = struct.unpack("<I", f.read(4))[0]
-            info[name] = read_value(f, data_type)
-
-        tensor_info = {}
-        for _ in range(n_tensors):
-            name = read_value(f, DATA_TYPES["string"])
-            shape_len = read_value(f, DATA_TYPES["uint32"])
-            shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
-            ggml_type = read_value(f, DATA_TYPES["uint32"])
-            bad_offset = read_value(f, DATA_TYPES["uint64"])
-            n_elems = int(math.prod(shape))
-            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
-            n_bytes = n_elems * type_size // block_size
-            np_dims = tuple(reversed(shape))
-
-            item_type: npt.DTypeLike
-            if ggml_type == GGMLQuantizationType.F16:
-                item_count = n_elems
-                item_type = np.float16
-            elif ggml_type == GGMLQuantizationType.F32:
-                item_count = n_elems
-                item_type = np.float32
-            elif ggml_type == GGMLQuantizationType.F64:
-                item_count = n_elems
-                item_type = np.float64
-            elif ggml_type == GGMLQuantizationType.I8:
-                item_count = n_elems
-                item_type = np.int8
-            elif ggml_type == GGMLQuantizationType.I16:
-                item_count = n_elems
-                item_type = np.int16
-            elif ggml_type == GGMLQuantizationType.I32:
-                item_count = n_elems
-                item_type = np.int32
-            elif ggml_type == GGMLQuantizationType.I64:
-                item_count = n_elems
-                item_type = np.int64
-            else:
-                item_count = n_bytes
-                item_type = np.uint8
-                np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
-
-            tensor_info[name] = {
-                "ggml_type": ggml_type,
-                "shape": shape,
-                "bad_offset": bad_offset,
-                "item_type": item_type,
-                "item_count": item_count,
-                "np_dims": np_dims
-            }
-
-        start = f.tell()
-        # Alignment is 32 by default.
-        # https://github.com/ggerganov/ggml/blob/e1daebbf9d38d510ba456c4d50b4500a73ac2b14/docs/gguf.md?plain=1#L253
-        alignment = info.get("general.alignment", 32)
-
-        # Inconveniently, the offset defined in gguf files is relative to the
-        # end of the header and is unaligned.
-        # We need to compute the absolute file offset ourselves instead.
-        for t in tensor_info.values():
-            offset = start + t["bad_offset"]
-            offset += (alignment - offset % alignment) % alignment
-            t["offset"] = offset
-
-        for name in tensor_info:
-            self.tensor_file_map[name] = f.name
-        self.tensor_info.update(tensor_info)
-        self.gguf_file_meta.update(info)
-
-    def get_mmap_tensor(self, name):
-        t = self.tensor_info[name]
-        mmap_data = self.file_data_map[ self.tensor_file_map[name] ]
-        offset = t["offset"]
-        item_type = t["item_type"]
-        item_count = t["item_count"]
-        itemsize = int(np.empty([], dtype = item_type).itemsize)
-        return mmap_data[offset : offset + itemsize * item_count]
-
-    def get_undequanted_tensor_and_ggml_type(self, name):
-        t = self.tensor_info[name]
-        data = self.get_mmap_tensor(name)
-        ggml_type = t["ggml_type"]
-        data = torch.from_numpy(data)
-        return data, ggml_type
-
-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
-        t = self.tensor_info[name]
-        if device.lower() == "cpu":
-            print(f"loading expert {expert_id} of {name} with CPU")
-        shape = t["shape"]
-        ggml_type = t["ggml_type"]
-        if ggml_type not in GGML_NAMES:
-            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
-        ggml_name = GGML_NAMES[ggml_type]
-
-        # TODO: experts may fused in quant block, split it
-        assert elements_per_expert % GGML_ELEMENTS_PER_BLOCK[ggml_name] == 0, "experts may fused in quant block, please use CPU dequant"
-        blocks_per_experts = elements_per_expert // GGML_ELEMENTS_PER_BLOCK[ggml_name]
-        block_size = GGML_BLOCK_SIZES[ggml_name]
-        offset = expert_id * block_size * blocks_per_experts
-        data = data[offset: offset + block_size * blocks_per_experts]
-
-        if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
-        else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values.copy())
-
-        if ggml_name == "BF16":
-            values = values.view(torch.bfloat16)
-        values = values.view(shape[-2::-1])
-
-        return values
-
-    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->torch.Tensor:
-        t = self.tensor_info[name]
-        if device.lower() == "cpu":
-            print(f"loading {name} with CPU")
-        if target_dtype == None:
-            target_dtype = torch.get_default_dtype()
-
-        shape = t["shape"]
-        ggml_type = t["ggml_type"]
-        if ggml_type not in GGML_NAMES:
-            raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
-        ggml_name = GGML_NAMES[ggml_type]
-
-        data = self.get_mmap_tensor(name)
-
-        block_size = GGML_BLOCK_SIZES[ggml_name]
-        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
-        num_elements = int(np.prod(shape))
-        num_blocks = num_elements // elements_per_block
-
-        blocks_per_iter = 16384
-        if num_blocks > blocks_per_iter: # dequant large tensor
-            values = torch.empty((num_blocks, elements_per_block), dtype=target_dtype, device=device)
-            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
-                blocks_begin = i * blocks_per_iter
-                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
-                if "cuda" in device.lower():
-                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
-                else:
-                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
-                    cur_values = torch.from_numpy(cur_values.copy())
-
-                cur_values = cur_values.view(-1, elements_per_block)
-                if ggml_name == "BF16":
-                    cur_values = cur_values.view(torch.bfloat16)
-                values[blocks_begin : blocks_end] = cur_values
-        else:
-            if "cuda" in device.lower():
-                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
-            else:
-                values = GGML_DEQUANTIZE[ggml_name](data)
-                values = torch.from_numpy(values)
-
-        if ggml_name == "BF16":
-            values = values.view(torch.bfloat16)
-        values = values.view(shape[::-1])
-
-        if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
-            n_head = self.gguf_file_meta['llama.attention.head_count']
-            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
-                        .swapaxes(1, 2)
-                        .reshape(values.shape))
-        elif "attn_k" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
-            n_head = self.gguf_file_meta['llama.attention.head_count_kv']
-            values = (values.reshape(n_head, values.shape[0] // n_head // 2, 2, *values.shape[1:])
-                        .swapaxes(1, 2)
-                        .reshape(values.shape))
-
-        return values
 def read_value(f, data_type):
     if data_type == DATA_TYPES["string"]:
...
@@ -921,6 +688,7 @@ def translate_name_to_gguf(name):
     name = name.replace(".gate_up_proj.", ".up_proj")
     name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
+    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
     name = name.replace(".mlp.gate", ".ffn_gate_inp")
     name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
     name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
...
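The rule added above maps the DeepSeek-style MoE routing bias onto its GGUF tensor name. Applying just this one replacement in isolation (the real translate_name_to_gguf applies many more rules, so the full output differs):

    name = "model.layers.5.mlp.gate.e_score_correction_bias"   # hypothetical HF-style key
    name = name.replace(".mlp.gate.e_score_correction_bias", ".exp_probs_b.bias")
    print(name)  # model.layers.5.exp_probs_b.bias

Ordering matters here: because the new rule runs before the ".mlp.gate" replacement, the correction bias is no longer rewritten into a ".ffn_gate_inp.e_score_correction_bias" key.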
This diff is collapsed.
@@ -22,8 +22,7 @@ from transformers import (
     EtaLogitsWarper,
 )
-from ktransformers.util.custom_gguf import translate_name_to_gguf
-from ktransformers.util.custom_gguf import GGUFLoader
+from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, GGUFLoader, translate_name_to_gguf
 from ktransformers.operators import base_operator
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
@@ -98,25 +97,24 @@ def get_all_used_cuda_device(device_map:dict):
     all_device_list = list(all_device_list)
     return all_device_list

-def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
+def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str = ""):
     prefix = prefix.replace("orig_module.", "")
     persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
     local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
     local_state = {k: v for k, v in local_name_params if v is not None}
     for name, param in local_state.items():
         key = prefix + name
-        translated_key = translate_name_to_gguf(key)
+        translated_key = key

         # TODO: Merge all loader.
         # I know this is ugly but lets do it for now.
-        if gguf_loader.safetensor_loader is not None:
-            load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
-            tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
+        if isinstance(gguf_loader, SafeTensorLoader):
+            load_dequantized_tensor = gguf_loader.load_dequantized_tensor
         else:
             load_dequantized_tensor = gguf_loader.load_gguf_tensor
             tensor_file_map = gguf_loader.tensor_file_map

-        if translated_key in tensor_file_map:
+        if gguf_loader.has_tensor(translated_key):
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
@@ -128,7 +126,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")

-def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
+def load_weights(module:nn.Module, gguf_loader:ModelLoader, prefix=''):
     #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
...
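load_cur_state_dict now calls gguf_loader.has_tensor(translated_key), but that method is not shown in the loader excerpt below. A minimal sketch of what it plausibly looks like for each loader, inferred from the diff rather than copied from the PR:

    from ktransformers.util.custom_loader import translate_name_to_gguf

    def safetensor_has_tensor(loader, name: str) -> bool:
        # Safetensor checkpoints keep HF-style names, so a direct lookup suffices.
        return name in loader.tensor_file_map

    def gguf_has_tensor(loader, name: str) -> bool:
        # GGUF files store translated names; since translated_key is now the raw key,
        # the name translation presumably moves into the GGUF loader itself.
        return translate_name_to_gguf(name) in loader.tensor_file_map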
from abc import ABC, abstractmethod
import os
import struct      # used by the GGUF header parsing below
import warnings    # used to warn about untested GGUF versions
import torch
import numpy as np
from safetensors import safe_open
from typing import Dict, Any, Optional, Union
class ModelLoader(ABC):
    """
    Abstract base class for model loaders.
    Defines the interface that all model loaders must implement.
    """

    @abstractmethod
    def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        """
        Load a tensor by name.

        Args:
            name: Name of the tensor to load
            device: Device to load the tensor to

        Returns:
            The loaded tensor
        """
        pass

    @classmethod
    @abstractmethod
    def supports_format(cls, path: str) -> bool:
        """
        Check if this loader supports the given path format.

        Args:
            path: Path to check

        Returns:
            True if this loader supports the given path, False otherwise
        """
        pass
class SafeTensorLoader(ModelLoader):
    """
    Loader for SafeTensor format models.
    """

    def __init__(self, path: str):
        """
        Initialize the SafeTensor loader.

        Args:
            path: Path to the model directory or file
        """
        self.tensor_file_map = {}  # Maps tensor names to file paths
        self.file_handle_map = {}  # Maps file names to file handles
        self._load_tensor_file_map(path)

    def _load_tensor_file_map(self, path: str) -> None:
        """
        Load the tensor file map from the given path.

        Args:
            path: Path to the model directory or file
        """
        # Normalize path to directory
        if not os.path.exists(path):
            raise FileNotFoundError(f"Path not found: {path}")
        if os.path.isfile(path):
            folder_path = os.path.dirname(path)
        else:
            folder_path = path

        found_safetensor = False
        for root, _, files in os.walk(folder_path):
            files = sorted(files)
            for file in files:
                if file.endswith(".safetensors"):
                    found_safetensor = True
                    file_path = os.path.join(root, file)
                    if file not in self.file_handle_map:
                        try:
                            handle = safe_open(file_path, framework="pt")
                            self.file_handle_map[file] = handle
                        except Exception as e:
                            print(f"Error opening Safetensor file {file_path}: {e}")
                            continue
                    f = self.file_handle_map.get(file)
                    if f is None:
                        continue
                    try:
                        for key in f.keys():
                            self.tensor_file_map[key] = file
                    except Exception as e:
                        print(f"Error reading Safetensor file {file_path}: {e}")

        if not found_safetensor:
            # Not raising an error here allows for the factory to try other loaders
            print(f"No Safetensor files found in {folder_path}")

    def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        """
        Load a tensor by name.

        Args:
            name: Name of the tensor to load
            device: Device to load the tensor to

        Returns:
            The loaded tensor
        """
        if name not in self.tensor_file_map:
            raise KeyError(f"Key {name} not found in Safetensor files")
        file = self.tensor_file_map[name]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(name)
        return tensor.to(device)

    def load_dequantized_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        """
        Load and dequantize a tensor.

        Args:
            name: Name of the tensor to load
            device: Device to load the tensor to

        Returns:
            The dequantized tensor
        """
        if name not in self.tensor_file_map:
            raise KeyError(f"Key {name} not found in Safetensor files")
        file = self.tensor_file_map[name]
        f = self.file_handle_map.get(file)
        if f is None:
            raise FileNotFoundError(f"File {file} not found in Safetensor files")
        tensor = f.get_tensor(name).to(device)
        if name.endswith(".weight"):
            if name[:-7] + ".weight_scale_inv" in self.tensor_file_map:
                weight_scale_inv = f.get_tensor(name[:-7] + ".weight_scale_inv").to(device)
                # Assuming weight_dequant function is imported
                from ktransformers.ktransformers_ext.triton.fp8gemm import weight_dequant
                tensor = weight_dequant(tensor, weight_scale_inv)
        return tensor.to(device)

    def close_all_handles(self) -> None:
        """
        Close all file handles.
        """
        for handle in self.file_handle_map.values():
            handle.close()
        self.file_handle_map.clear()

    @classmethod
    def supports_format(cls, path: str) -> bool:
        """
        Check if this loader supports the given path format.

        Args:
            path: Path to check

        Returns:
            True if safetensor files are found in the path, False otherwise
        """
        # Normalize path to directory
        if not os.path.exists(path):
            return False
        if os.path.isfile(path):
            if path.endswith(".safetensors"):
                return True
            folder_path = os.path.dirname(path)
        else:
            folder_path = path

        # Check if any safetensor files exist in the folder
        for root, _, files in os.walk(folder_path):
            for file in files:
                if file.endswith(".safetensors"):
                    return True
        return False
class GGUFLoader(ModelLoader):
    """
    Loader for GGUF format models.
    """

    def __init__(self, path: str):
        """
        Initialize the GGUF loader.

        Args:
            path: Path to the model directory or file
        """
        # Check if path exists
        if not os.path.exists(path):
            raise FileNotFoundError(f"GGUF dir not found: {path}")
        if os.path.isfile(path):
            self.gguf_path = os.path.dirname(path)
        else:
            self.gguf_path = path

        self.tensor_info = {}      # Stores tensor metadata
        self.tensor_file_map = {}  # Maps tensor names to file paths
        self.file_data_map = {}    # Maps file paths to memory-mapped data
        self.gguf_file_meta = {}   # Stores GGUF metadata

        # For compatibility with the factory pattern
        self.safetensor_loader = None

        # Scan all GGUF files in the directory
        found_gguf = False
        for root, _, files in os.walk(self.gguf_path):
            for file in files:
                if file.endswith(".gguf"):
                    found_gguf = True
                    file_path = os.path.join(root, file)
                    with open(file_path, "rb") as f:
                        self._load_gguf(f)
                        if file_path not in self.file_data_map:
                            self.file_data_map[file_path] = np.memmap(file_path, mode='r')
        if not found_gguf:
            raise FileNotFoundError(f"Cannot find any .gguf files in: {self.gguf_path}")

    def _load_gguf(self, f) -> None:
        """
        Load GGUF file metadata and tensor info.

        Args:
            f: File handle of the GGUF file
        """
        # Implementation should follow the original GGUFLoader.load_gguf
        # This is a simplified version for illustration
        f.seek(0)
        assert f.read(4) == b'GGUF'

        # Read header
        values = struct.unpack("<IQQ", f.read(4+8+8))
        version, n_tensors, n_kv = values
        if version != 3:
            warnings.warn(f"Version {version} has never been tested, might not work")

        # Read key-value pairs
        info = {}
        for _ in range(n_kv):
            name = self._read_value(f, 8)  # DATA_TYPES["string"]
            data_type = struct.unpack("<I", f.read(4))[0]
            info[name] = self._read_value(f, data_type)

        # Read tensor info
        tensor_info = {}
        for _ in range(n_tensors):
            name = self._read_value(f, 8)       # DATA_TYPES["string"]
            shape_len = self._read_value(f, 4)  # DATA_TYPES["uint32"]
            shape = [self._read_value(f, 10) for _ in range(shape_len)]  # DATA_TYPES["uint64"]
            ggml_type = self._read_value(f, 4)  # DATA_TYPES["uint32"]
            offset = self._read_value(f, 10)    # DATA_TYPES["uint64"]

            # Additional tensor metadata would be calculated here
            # For brevity, we're omitting the detailed tensor metadata calculation
            tensor_info[name] = {
                "ggml_type": ggml_type,
                "shape": shape,
                "offset": offset,
                # ... other tensor metadata
            }

        start = f.tell()
        alignment = info.get("general.alignment", 32)

        # Calculate actual file offsets
        for t in tensor_info.values():
            offset = start + t["offset"]
            offset += (alignment - offset % alignment) % alignment
            t["offset"] = offset

        # Update file maps
        for name in tensor_info:
            self.tensor_file_map[name] = f.name
        self.tensor_info.update(tensor_info)
        self.gguf_file_meta.update(info)

    def _read_value(self, f, data_type) -> Any:
        """
        Read a value from the file according to its data type.

        Args:
            f: File handle
            data_type: Type of data to read

        Returns:
            The read value
        """
        # Simplified implementation
        # In a complete implementation, this would handle all data types
        if data_type == 8:  # DATA_TYPES["string"]
            length = struct.unpack("<Q", f.read(8))[0]
            return f.read(length).decode("utf-8")
        elif data_type == 4:  # DATA_TYPES["uint32"]
            return struct.unpack("<I", f.read(4))[0]
        elif data_type == 10:  # DATA_TYPES["uint64"]
            return struct.unpack("<Q", f.read(8))[0]
        # ... handling for other data types
        return None

    def load_tensor(self, name: str, device: str = "cpu") -> torch.Tensor:
        """
        Load a tensor by name.

        Args:
            name: Name of the tensor to load
            device: Device to load the tensor to

        Returns:
            The loaded tensor
        """
        # This should call load_gguf_tensor with the appropriate parameters
        return self.load_gguf_tensor(name, device)

    def load_gguf_tensor(self, name: str, device: str = "cpu", target_dtype = None) -> torch.Tensor:
        """
        Load a GGUF tensor by name.

        Args:
            name: Name of the tensor to load
            device: Device to load the tensor to
            target_dtype: Target data type for the tensor

        Returns:
            The loaded tensor
        """
        # Implementation would follow the original GGUFLoader.load_gguf_tensor
        # This is a placeholder for illustration
        if name not in self.tensor_info:
            raise KeyError(f"Tensor {name} not found")

        # Actual implementation would dequantize the tensor data
        # and return a torch.Tensor
        return torch.zeros(1, device=device)  # Placeholder

    @classmethod
    def supports_format(cls, path: str) -> bool:
        """
        Check if this loader supports the given path format.

        Args:
            path: Path to check

        Returns:
            True if GGUF files are found in the path, False otherwise
        """
        # Normalize path to directory
        if not os.path.exists(path):
            return False
        if os.path.isfile(path):
            return path.endswith(".gguf")

        # Check if any GGUF files exist in the folder
        for root, _, files in os.walk(path):
            for file in files:
                if file.endswith(".gguf"):
                    return True
        return False
\ No newline at end of file
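ModelLoaderFactory is imported alongside these classes in the utils hunk above but is not part of this excerpt. A sketch of how such a factory could dispatch on supports_format, under the assumption that it simply prefers safetensors over GGUF (class and method names here are assumptions, not the PR's actual implementation):

    from ktransformers.util.custom_loader import ModelLoader, SafeTensorLoader, GGUFLoader

    class ModelLoaderFactorySketch:
        # Preference order is an assumption: try safetensors first, then GGUF.
        loader_classes = [SafeTensorLoader, GGUFLoader]

        @classmethod
        def create_loader(cls, path: str) -> ModelLoader:
            for loader_cls in cls.loader_classes:
                if loader_cls.supports_format(path):
                    return loader_cls(path)
            raise FileNotFoundError(f"No supported model files found in: {path}")

With something like this, a call site can obtain a loader from a checkpoint directory and then branch on isinstance(loader, SafeTensorLoader), exactly as load_cur_state_dict does above.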
@@ -6,7 +6,7 @@ import sys
 # sys.path.insert(0, "/home/azure/ktransformers")
 import argparse
 import torch
-from ktransformers.util.custom_gguf import GGUFLoader, translate_name_to_gguf
+from ktransformers.util.custom_loader import GGUFLoader, translate_name_to_gguf
 from safetensors import safe_open
 from safetensors.torch import save_file
 import re
...