Commit 7e1fe256 authored by Atream

optimize GPU

parent cf4da5fd
@@ -20,19 +20,19 @@
 PYBIND11_MODULE(KTransformersOps, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
           py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
           py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
...
@@ -17,19 +17,19 @@ torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device de
 PYBIND11_MODULE(cudaops, m) {
     m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
-          py::arg("data"), py::arg("blk_size"), py::arg("device"));
+          py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("device"), py::arg("target_dtype"));
     m.def("test", &test, "Function to test.");
 }
@@ -13,10 +13,10 @@
 #include <torch/extension.h>
 #include <torch/torch.h>
-torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
-torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device);
+torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
+torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const torch::Device device, const torch::ScalarType target_dtype);
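Note: the dequantizers now take a raw buffer pointer plus a byte count instead of a torch::Tensor, which is what lets the Python side drop the torch.from_numpy round-trip flagged in the TODO comments below. A minimal sketch of the matching call pattern, mirroring the GGUFLoader call sites later in this commit (the input buffer, device, and dtype are illustrative only):

import numpy as np
import torch
import KTransformersOps

# q8_0_bytes: a bytes-like slice of the GGUF mmap (illustrative).
# Q8_0 blocks are 34 bytes each: one fp16 scale followed by 32 int8 weights.
raw = np.frombuffer(q8_0_bytes, dtype=np.int8)
values = KTransformersOps.dequantize_q8_0(
    raw.data,                # raw buffer instead of a torch.Tensor
    raw.size,                # num_bytes (element count == bytes for int8)
    34,                      # blk_size for Q8_0
    torch.device("cuda:0"),  # device
    torch.float16,           # target_dtype
)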
@@ -168,10 +168,7 @@ def local_chat(
     if mode == 'long_context':
         assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
         "please change max_seq_len in ~/.ktransformers/config.yaml"
-    torch.set_default_dtype(
-        torch.bfloat16
-    ) # TODO: Remove this, replace dtype using config
     if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
         generated = prefill_and_generate(
             model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
...
@@ -5,6 +5,18 @@
     kwargs:
       generate_device: "cuda"
       prefill_device: "cuda"
+- match:
+    name: "^lm_head$" # regular expression
+    class: torch.nn.Linear # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
     class: torch.nn.Linear # only match modules matching name and class simultaneously
...
@@ -25,10 +25,10 @@ class KTransformersThreadContext(TransformersThreadContext):
 class KTransformersInterface(TransformersInterface):
     def __init__(self, args: ConfigArgs = default_args):
         self.args = args
-        torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
+        torch.set_default_dtype(config.torch_dtype)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
...
@@ -285,7 +285,7 @@ class GGUFLoader:
         itemsize = int(np.empty([], dtype = item_type).itemsize)
         return mmap_data[offset : offset + itemsize * item_count]
-    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "gpu")->torch.Tensor:
+    def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading expert {expert_id} of {name} with CPU")
@@ -304,7 +304,7 @@ class GGUFLoader:
         data = data[offset: offset + block_size * blocks_per_experts]
         if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
         else:
             values = GGML_DEQUANTIZE[ggml_name](data)
             values = torch.from_numpy(values)
@@ -313,7 +313,7 @@ class GGUFLoader:
         return values
-    def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
+    def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = torch.get_default_dtype())->torch.Tensor:
         t = self.tensor_info[name]
         if device.lower() == "cpu":
             print(f"loading {name} with CPU")
@@ -328,16 +328,36 @@ class GGUFLoader:
         data = self.get_mmap_tensor(name)
-        if "cuda" in device.lower():
-            values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
-            #values = GGML_DEQUANTIZE[ggml_name](data)
-            #print("load_gguf_tensor")
-            #values = torch.from_numpy(values).to(device = device)
-        else:
-            values = GGML_DEQUANTIZE[ggml_name](data)
-            values = torch.from_numpy(values)
+        block_size = GGML_BLOCK_SIZES[ggml_name]
+        elements_per_block = GGML_ELEMENTS_PER_BLOCK[ggml_name]
+        num_elements = int(np.prod(shape))
+        num_blocks = num_elements // elements_per_block
+        blocks_per_iter = 16384
+        if num_blocks > blocks_per_iter: # dequant large tensor
+            values = torch.empty((num_blocks, elements_per_block), dtype=torch.float, device=device)
+            for i in range( (num_blocks + blocks_per_iter - 1) // blocks_per_iter):
+                blocks_begin = i * blocks_per_iter
+                blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
+                if "cuda" in device.lower():
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                else:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values)
+                cur_values = cur_values.view(-1, elements_per_block)
+                values[blocks_begin : blocks_end] = cur_values
+        else:
+            if "cuda" in device.lower():
+                values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
+            else:
+                values = GGML_DEQUANTIZE[ggml_name](data)
+                values = torch.from_numpy(values)
         if ggml_name == "BF16":
             values = values.view(torch.bfloat16)
         values = values.view(shape[::-1])
         if "attn_q" in name and self.gguf_file_meta['general.architecture'] in ["llama"]:
             n_head = self.gguf_file_meta['llama.attention.head_count']
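For reference, a standalone sketch of the chunking arithmetic the new loop relies on, with illustrative sizes. Processing blocks_per_iter blocks per call keeps each GPU dequantization working on a bounded slice of the raw data instead of the whole tensor; the Q6_K layout (210 bytes, 256 weights per block) and the tensor size below are assumptions for illustration:

# assumptions: Q6_K block layout and an illustrative large weight matrix
block_size = 210                # GGML_BLOCK_SIZES["Q6_K"], bytes per block
elements_per_block = 256        # GGML_ELEMENTS_PER_BLOCK["Q6_K"], weights per block
num_elements = 16384 * 7168     # illustrative tensor size
num_blocks = num_elements // elements_per_block        # 458752 blocks
blocks_per_iter = 16384

for i in range((num_blocks + blocks_per_iter - 1) // blocks_per_iter):   # 28 iterations here
    begin = i * blocks_per_iter
    end = min(begin + blocks_per_iter, num_blocks)
    raw_bytes = slice(begin * block_size, end * block_size)   # bytes handed to one dequant call
    out_rows = slice(begin, end)                              # rows filled in the preallocated output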
@@ -433,14 +453,13 @@ def dequantize_q2_k(data):
     return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
-def dequantize_q2_k_gpu(data, device:str ="cuda"):
+def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q2_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q2_k(data, block_size, device)
+    return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype)
 def dequantize_q3_k(data):
     # C implementation
@@ -484,14 +503,13 @@ def dequantize_q3_k(data):
         (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
     ], axis=1)
-def dequantize_q3_k_gpu(data, device:str ="cuda"):
+def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q3_k(data, block_size, device)
+    return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype)
 def dequantize_q4_k(data):
     # C implementation
@@ -515,13 +533,12 @@ def dequantize_q4_k(data):
     # Dequantize final weights using scales and offsets
     return factors * qs2 - offsets
-def dequantize_q4_k_gpu(data, device:str ="cuda"):
+def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q4_k(data, 144, device)
+    return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype)
 def dequantize_q5_k(data):
     # C implementation
@@ -579,14 +596,13 @@ def dequantize_q5_k(data):
         d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
     ], axis=1)
-def dequantize_q5_k_gpu(data, device:str ="cuda"):
+def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q5_k(data, block_size, device)
+    return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype)
 def dequantize_q6_k(data):
     # C implementation
@@ -637,13 +653,12 @@ def dequantize_q6_k(data):
     ], axis=1)
 # @torch.jit.script
-def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q6_k(data, block_size, device)
+    return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype)
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
@@ -677,13 +692,12 @@ def dequantize_iq4_xs(data):
     return y.flatten()
-def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda"):
+def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_iq4_xs(data, block_size, device)
+    return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype)
 def dequantize_q4_0(data):
     # C implementation
@@ -700,7 +714,7 @@ def dequantize_q4_0(data):
         scales * ((qs >> 4).astype(np.int8) - 8),
     ], axis=1)
-def dequantize_q4_0_gpu(data):
+def dequantize_q4_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 def dequantize_q5_0(data):
@@ -724,7 +738,7 @@ def dequantize_q5_0(data):
         scales * x1,
     ], axis=1)
-def dequantize_q5_0_gpu(data):
+def dequantize_q5_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     raise NotImplementedError()
 def dequantize_q8_0(data):
@@ -736,20 +750,19 @@ def dequantize_q8_0(data):
     qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
     return scales * qs
-def dequantize_q8_0_gpu(data, device:str = "cuda"):
+def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
     num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    data = torch.from_numpy(data)
-    return KTransformersOps.dequantize_q8_0(data, 34, device)
+    return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype)
 def dequantize_f32(data):
     return np.frombuffer(data, dtype=np.float32)
-def dequantize_f32_gpu(data, device):
+def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
     res = torch.from_numpy(data)
     res_gpu = torch.empty_like(res, device=device)
@@ -759,7 +772,7 @@ def dequantize_f32_gpu(data, device):
 def dequantize_f16(data):
     return np.frombuffer(data, dtype=np.float16)
-def dequantize_f16_gpu(data, device):
+def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
     res = torch.from_numpy(data)
     res_gpu = torch.empty_like(res, device=device)
...
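Taken together, load_gguf_tensor and the *_gpu dequantizers now thread a target_dtype through to the CUDA kernels, with the default following torch.get_default_dtype(). An illustrative call sketch; the import path, constructor argument, and tensor name are assumptions for illustration, not taken from this commit:

import torch
from ktransformers.util.custom_gguf import GGUFLoader   # module path assumed from the project layout

loader = GGUFLoader("/path/to/gguf/dir")                 # placeholder path
w = loader.load_gguf_tensor("blk.0.attn_q.weight",       # placeholder tensor name
                            device="cuda:0",
                            target_dtype=torch.float16)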