Commit 5ec33d04 authored by Atream

optimize GGUF dequant, save memory, support Q2_K

use Marlin for lm_head; lm_head only computes the last token during prefill
extend context window to 19K tokens for DeepSeek-V3/R1 within 24GB VRAM
parent 7e1fe256
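Note on the lm_head prefill change: during prefill only the final position's logits are needed to sample the first generated token, so projecting every prompt position through the vocabulary matrix wastes compute and VRAM. A minimal sketch of the idea (illustrative names, not the actual KTransformers code):

    import torch

    def project_last_token_only(hidden_states: torch.Tensor, lm_head: torch.nn.Linear) -> torch.Tensor:
        # hidden_states: [batch, seq_len, hidden_dim] from the prefill forward pass.
        # Slice to the last position before the large vocab projection, so the matmul is
        # [batch, 1, hidden] x [hidden, vocab] instead of [batch, seq_len, hidden] x [hidden, vocab].
        last_hidden = hidden_states[:, -1:, :]
        return lm_head(last_hidden)  # [batch, 1, vocab_size]

    # Example: a 4K-token prompt only pays for one row of logits.
    lm_head = torch.nn.Linear(1024, 32000, bias=False)
    logits = project_last_token_only(torch.randn(1, 4096, 1024), lm_head)
    print(logits.shape)  # torch.Size([1, 1, 32000])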
@@ -135,7 +135,18 @@
       prefill_device: "cuda:0"
 - match:
-    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+- match:
+    name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
   replace:
     class: "default"
     kwargs:
...
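Note: the pattern change above is what lets lm_head reach the Marlin kernel at all. The rules match module path names by regular expression (as the escaped patterns suggest), and the old catch-all also matched "lm_head", so the output head stayed pinned to the plain "default" replacement. A quick illustration with plain re (not the actual injection engine):

    import re

    old_default = r"(^model\.layers\.([3456][0-9])\.)|(model.norm)|(lm_head)"
    new_lm_head = r"^lm_head"

    # The old catch-all swallowed the output head...
    assert re.search(old_default, "lm_head")
    # ...while the dedicated rule matches it and nothing else in the layer stack.
    assert re.search(new_lm_head, "lm_head")
    assert re.search(new_lm_head, "model.layers.35.mlp") is None
    assert re.search(old_default, "model.layers.35.mlp")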
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.block_sparse_moe$"
     class: ktransformers.models.modeling_mixtral.MixtralSparseMoeBlock
...
@@ -77,9 +77,19 @@
     kwargs:
       generate_device: "cpu"
       prefill_device: "cpu"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda:1"
+      prefill_device: "cuda:1"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
-    name: "(^model.norm)|(^lm_head)"
+    name: "(^model.norm)"
   replace:
     class: "default"
     kwargs:
...
@@ -15,6 +15,16 @@
       prefill_device: "cuda"
       generate_op: "KLinearMarlin"
       prefill_op: "KLinearTorch"
+- match:
+    name: "^lm_head"
+    class: torch.nn.Linear
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
 - match:
     name: "^model\\.layers\\..*\\.mlp$"
     class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
...
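The same lm_head rule is added to each of the four rule files above (two place it on "cuda", the other two on "cuda:1"). A hypothetical post-load sanity check, assuming a model already optimized with one of these files (check_lm_head and its use are illustrative, not part of this commit):

    import torch
    from ktransformers.operators.linear import KTransformersLinear  # class path taken from the rules above

    def check_lm_head(model: torch.nn.Module) -> None:
        # After injection, lm_head should be the KTransformers linear wrapper
        # (Marlin for generate, plain torch for prefill), not a bare nn.Linear.
        head = model.get_submodule("lm_head")
        assert isinstance(head, KTransformersLinear), f"lm_head not injected: {type(head)}"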
@@ -25,6 +25,7 @@ import os
 from enum import IntEnum
 import torch
 import KTransformersOps
+import ctypes
 
 class GGMLQuantizationType(IntEnum):
     F32 = 0
@@ -307,7 +308,7 @@ class GGUFLoader:
                 values = GGML_DEQUANTIZE_GPU[ggml_name](data, device, target_dtype)
             else:
                 values = GGML_DEQUANTIZE[ggml_name](data)
-                values = torch.from_numpy(values)
+                values = torch.from_numpy(values.copy())
 
         values = values.view(shape[-2::-1])
@@ -343,7 +344,7 @@ class GGUFLoader:
                 cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
             else:
                 cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
-                cur_values = torch.from_numpy(cur_values)
+                cur_values = torch.from_numpy(cur_values.copy())
 
             cur_values = cur_values.view(-1, elements_per_block)
             values[blocks_begin : blocks_end] = cur_values
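Note: the .copy() added in both hunks above sidesteps the warning the TODO comments further down mention. np.frombuffer returns a read-only, zero-copy view of the underlying bytes, and torch.from_numpy warns when handed non-writable memory because the resulting tensor would share it. A standalone illustration:

    import numpy as np
    import torch

    raw = bytes(16)                             # stand-in for a GGUF tensor blob
    arr = np.frombuffer(raw, dtype=np.float32)  # zero-copy view over immutable bytes
    print(arr.flags.writeable)                  # False
    # torch.from_numpy(arr) would emit "The given NumPy array is not writable";
    # copying first gives torch an owned, writable buffer.
    values = torch.from_numpy(arr.copy())
    print(values.shape)                         # torch.Size([4])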
@@ -455,11 +456,13 @@ def dequantize_q2_k(data):
 
 def dequantize_q2_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q2_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q2_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    return KTransformersOps.dequantize_q2_k(data.data, data.size, block_size, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q2_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q3_k(data):
     # C implementation
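Note: the new c_pointer line does what the TODO comment asks for, passing a raw address (a plain Python int) to the C++/CUDA extension instead of a NumPy buffer. data.ctypes.data is already that address; casting it to POINTER(c_int8) and taking addressof(...contents) round-trips the same integer. A small self-contained check:

    import ctypes
    import numpy as np

    data = np.frombuffer(bytes(range(8)), dtype=np.int8)
    addr = data.ctypes.data                                 # integer address of the first element
    ptr = ctypes.cast(addr, ctypes.POINTER(ctypes.c_int8))  # typed pointer at that address
    assert ctypes.addressof(ptr.contents) == addr           # the same address comes back out
    assert ptr[3] == data[3]                                # and it dereferences to the same bytes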
@@ -505,11 +508,13 @@ def dequantize_q3_k(data):
 
 def dequantize_q3_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q3_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q3_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    return KTransformersOps.dequantize_q3_k(data.data, data.size, block_size, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q3_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q4_k(data):
     # C implementation
@@ -534,11 +539,14 @@ def dequantize_q4_k(data):
     return factors * qs2 - offsets
 
 def dequantize_q4_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
+    block_size = GGML_BLOCK_SIZES["Q4_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q4_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    return KTransformersOps.dequantize_q4_k(data.data, data.size, 144, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q4_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q5_k(data):
     # C implementation
@@ -598,11 +606,13 @@ def dequantize_q5_k(data):
 
 def dequantize_q5_k_gpu(data, device:str ="cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q5_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q5_K"]
     data = np.frombuffer(data, dtype=data.dtype)
     device = torch.device(device)
     # TODO: this and from_numpy in other functions will cause a warning saying that numpy is not writable,
     # the best way to fix this is transfer ptr to KTransformersOps instead of Tensor.
-    return KTransformersOps.dequantize_q5_k(data.data, data.size, block_size, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q5_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q6_k(data):
     # C implementation
@@ -655,10 +665,12 @@ def dequantize_q6_k(data):
 
 # @torch.jit.script
 def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["Q6_K"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q6_K"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    return KTransformersOps.dequantize_q6_k(data.data, data.size, block_size, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q6_k(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
@@ -694,10 +706,12 @@ def dequantize_iq4_xs(data):
 
 def dequantize_iq4_xs_gpu(data: np.ndarray, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["IQ4_XS"]
     device = torch.device(device)
     num_blocks = len(data) // block_size
     data = np.frombuffer(data, dtype=data.dtype)
-    return KTransformersOps.dequantize_iq4_xs(data.data, data.size, block_size, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_iq4_xs(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_q4_0(data):
     # C implementation
@@ -753,10 +767,13 @@ def dequantize_q8_0(data):
 
 def dequantize_q8_0_gpu(data, device:str = "cuda", target_dtype = torch.get_default_dtype()):
     # C struct definition
     # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
     num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
+    block_size = GGML_BLOCK_SIZES["Q8_0"]
+    ele_per_blk = GGML_ELEMENTS_PER_BLOCK["Q8_0"]
     device = torch.device(device)
     data = np.frombuffer(data, dtype=data.dtype)
-    return KTransformersOps.dequantize_q8_0(data.data, data.size, 34, device, target_dtype)
+    c_pointer = ctypes.addressof(ctypes.cast(data.ctypes.data, ctypes.POINTER(ctypes.c_int8)).contents)
+    return KTransformersOps.dequantize_q8_0(c_pointer, data.size, block_size, ele_per_blk, device, target_dtype)
 
 def dequantize_f32(data):
@@ -764,8 +781,8 @@ def dequantize_f32(data):
 
 def dequantize_f32_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float32)
-    res = torch.from_numpy(data)
-    res_gpu = torch.empty_like(res, device=device)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
     res_gpu.copy_(res)
     return res_gpu
@@ -774,7 +791,14 @@ def dequantize_f16(data):
 
 def dequantize_f16_gpu(data, device, target_dtype = torch.get_default_dtype()):
     data = np.frombuffer(data, dtype=np.float16)
-    res = torch.from_numpy(data)
+    res = torch.from_numpy(data.copy())
+    res_gpu = torch.empty_like(res, device=device, dtype=target_dtype)
+    res_gpu.copy_(res)
+    return res_gpu
+
+def dequantize_bf16_gpu(data, device, target_dtype = torch.get_default_dtype()):
+    data = np.frombuffer(data, dtype=np.float16)
+    res = torch.from_numpy(data.copy())
     res_gpu = torch.empty_like(res, device=device)
     res_gpu.copy_(res)
     return res_gpu
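Note: dequantize_f32_gpu and dequantize_f16_gpu now allocate the destination with dtype=target_dtype; torch.Tensor.copy_ converts dtype (and device) during the copy, so the cast to the requested compute dtype happens in the same step as the host-to-device transfer. Minimal CPU-only illustration:

    import torch

    src = torch.randn(4, dtype=torch.float16)
    dst = torch.empty_like(src, dtype=torch.bfloat16)  # allocate in the target dtype
    dst.copy_(src)                                     # copy_ casts fp16 -> bf16 on the fly
    print(dst.dtype)                                   # torch.bfloat16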
@@ -797,7 +821,7 @@ GGML_DEQUANTIZE = {
 GGML_DEQUANTIZE_GPU = {
     "F32": dequantize_f32_gpu,
     "F16": dequantize_f16_gpu,
-    "BF16": dequantize_f16_gpu,
+    "BF16": dequantize_bf16_gpu,
     "Q4_0": dequantize_q4_0_gpu,
     "Q5_0": dequantize_q5_0_gpu,
     "Q8_0": dequantize_q8_0_gpu,
...
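The dispatch table maps GGML type names to dequantizers, so BF16 tensors now get their own GPU path instead of being routed through the F16 one. Usage is a plain dictionary lookup; a hedged sketch of how loader code drives it (dequantize_on_gpu and raw_bytes are illustrative, not part of this commit):

    import torch
    # GGML_DEQUANTIZE_GPU is the table defined above in custom_gguf.py.

    def dequantize_on_gpu(ggml_name: str, raw_bytes, device: str = "cuda:0",
                          target_dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Look up the per-format wrapper and hand it the raw quantized bytes.
        dequant = GGML_DEQUANTIZE_GPU[ggml_name]
        return dequant(raw_bytes, device, target_dtype)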
@@ -79,7 +79,7 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
         raise Exception(f"can't find {translated_key} in GGUF file!")
 
 def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix=''):
-    # print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
+    #print(f"recursively loading weights {prefix}")
     if not isinstance(module, base_operator.BaseInjectedModule):
         load_cur_state_dict(module, gguf_loader, prefix)
     for name, child in module._modules.items():
...