Unverified commit 8db6a4d4 authored by Atream, committed by GitHub

Merge branch 'main' into main

parents cea07d19 3c8c5805
......@@ -16,6 +16,7 @@ from ktransformers.models.modeling_deepseek import DeepseekV2Attention, apply_ro
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import get_compute_capability
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
......@@ -48,12 +49,14 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
absorb_for_prefill: bool = False,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
        self.chunck_size = chunck_size # TODO: generate chunck_size automatically.
self.mla_wrapper = None
self.absorb_for_prefill = absorb_for_prefill
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
......@@ -242,7 +245,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(1, 2)
assert q_nope.is_contiguous()
#assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
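Note: the transpose-matmul-transpose pattern above leaves q_nope non-contiguous, which is why the contiguity assert was commented out. A minimal sketch of the absorbed projection, with illustrative DeepSeek-V2-like shapes (the numbers are assumptions, not taken from this diff):

import torch

bsz, q_len, num_heads = 1, 4, 128           # illustrative values
qk_nope_head_dim, kv_lora_rank = 128, 512   # illustrative values

q_nope = torch.randn(bsz, q_len, num_heads, qk_nope_head_dim)
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)

# move heads to dim 1 so the per-head absorb matrices batch over num_heads
out = torch.matmul(q_nope.transpose(1, 2), q_absorb).transpose(1, 2)
assert out.shape == (bsz, q_len, num_heads, kv_lora_rank)
assert not out.is_contiguous()  # the final transpose returns a strided view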
......@@ -282,6 +285,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2)
attn_output = torch.matmul(attn_output, out_absorb.mT)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output)
......@@ -380,7 +384,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
# decode
if q_len == 1:
if q_len == 1 or self.absorb_for_prefill:
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
compressed_kv_with_k_pe, page_table = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
......@@ -395,29 +399,42 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
q_nope = q_nope.transpose(1, 2) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(1, 2)
assert q_nope.is_contiguous()
q_nope = q_nope.contiguous()
#assert q_nope.is_contiguous()
# q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
# q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
q_nope.squeeze_(1)
q_pe.squeeze_(1)
q_nope.squeeze_(0)
q_pe.squeeze_(0)
            # flash attn doesn't support head_dim larger than 256, so use flashinfer
if self.mla_wrapper is None:
self.mla_wrapper = MLAWrapperSingleton.get_instance(self.device, 1, past_key_value.max_pages, use_cuda_graph = True)
if self.mla_wrapper.need_plan:
self.mla_wrapper.need_plan = False
if self.mla_wrapper.need_plan:
self.mla_wrapper.need_plan = False
if q_len == 1:
self.mla_wrapper.plan(None,None,None,
position_ids.squeeze(1)+1,
self.num_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
past_key_value.page_size,
self.softmax_scale,
q_nope.dtype,
compressed_kv.dtype)
position_ids.squeeze(1)+1,
self.num_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
past_key_value.page_size,
self.softmax_scale,
q_nope.dtype,
compressed_kv.dtype)
else:
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device=self.device)
kv_len_arr = torch.tensor([position_ids[0, -1].item()+1], dtype=torch.int32, device=self.device)
self.mla_wrapper.plan(qo_indptr,None,None,
kv_len_arr,
self.num_heads,
self.kv_lora_rank,
self.qk_rope_head_dim,
past_key_value.page_size,
self.softmax_scale,
q_nope.dtype,
compressed_kv.dtype)
attn_output = self.mla_wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(bsz, q_len, self.num_heads, self.kv_lora_rank)
"""
k = (
torch.cat([compressed_kv, k_pe], dim=-1)
......@@ -443,10 +460,11 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.v_head_dim]
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
else:
if past_key_value is not None:
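Aside: with batch size 1, the prefill plan above reduces to two small index tensors; a hedged sketch with made-up lengths:

import torch

q_len, kv_len = 7, 123  # made-up prefill length and total cached length
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32)   # [start, end) of each request's query rows
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32)    # cached kv entries per request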
......@@ -571,7 +589,8 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if os.name == 'nt':
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
return self.forward_windows(
hidden_states,
attention_mask,
......
......@@ -245,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
down_type = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
                # temporary (admittedly ugly) way to load the tensor
gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
......@@ -450,9 +459,9 @@ class KExpertsTorch(KExpertsBase):
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
self.up = torch.cat(self.up, dim=0)
self.gate = torch.cat(self.gate, dim=0)
self.down = torch.cat(self.down, dim=0)
self.up = torch.stack(self.up, dim=0)
self.gate = torch.stack(self.gate, dim=0)
self.down = torch.stack(self.down, dim=0)
return
def unload(self):
......
......@@ -9,7 +9,7 @@ flashinfer_enabled = False
try:
import flashinfer
flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
flashinfer_enabled = True
print("found flashinfer")
except ImportError:
......@@ -122,7 +122,7 @@ class MLAWrapper():
if kv_indices is None:
assert self.max_batch_size == 1
kv_indices = self.kv_indices_buf
self.wrapper.plan(
qo_indptr,
kv_indptr,
......@@ -132,14 +132,14 @@ class MLAWrapper():
head_dim_ckv,
head_dim_kpe,
page_size,
False, # causal is False for decoding
True, # causal
sm_scale,
q_data_type,
kv_data_type,
)
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse)
return self.wrapper.run(q_nope, q_pe, ckv, k_pe, return_lse = return_lse)
class MLAWrapperSingleton():
wrappers:dict = {}
......@@ -179,6 +179,24 @@ class MLAWrapperSingleton():
sm_scale,
q_data_type,
kv_data_type,)
wrapper.need_plan = False
@classmethod
def need_plan_all(cls):
for device, wrapper in cls.wrappers.items():
wrapper.need_plan = True
@classmethod
def reset_buffer(cls):
for device, wrapper in cls.wrappers.items():
            wrapper.qo_indptr_buf[1] = 1 # assumes max_batch_size == 1 here.
@classmethod
def update_buffer(cls, max_pages):
for device, wrapper in cls.wrappers.items():
            wrapper.kv_indptr_buf[1] = max_pages # assumes max_batch_size == 1 here.
wrapper.kv_indices_buf = torch.arange(0, max_pages, dtype=torch.int32, device=device)
wrapper.wrapper._kv_indices_buf = wrapper.kv_indices_buf
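A hedged sketch of what update_buffer leaves in the per-device buffers when one request spans all pages (values illustrative, max_batch_size == 1):

import torch

max_pages = 4
kv_indptr = torch.tensor([0, max_pages], dtype=torch.int32)  # the single request owns all pages
kv_indices = torch.arange(0, max_pages, dtype=torch.int32)   # page ids 0..max_pages-1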
if __name__ == "__main__":
......@@ -187,8 +205,9 @@ if __name__ == "__main__":
page_size = 64
num_heads = 128
q_nope = torch.randn((1, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((1, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_len = 10
q_nope = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
ckv = torch.randn((max_pages, page_size, 512), dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn((max_pages, page_size, 64), dtype=torch.bfloat16, device="cuda")
......@@ -199,10 +218,10 @@ if __name__ == "__main__":
max_pages,
)
kv_len_arr = torch.tensor([10], dtype=torch.int32, device="cuda")
kv_len_arr = torch.tensor([q_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
wrapper.plan(
None,
qo_indptr,
None,
None,
kv_len_arr,
......@@ -216,6 +235,7 @@ if __name__ == "__main__":
)
attn_output = wrapper.run(q_nope, q_pe, ckv, k_pe)
print(attn_output.shape)
k = (
torch.cat([ckv, k_pe], dim=-1)
......@@ -235,6 +255,7 @@ if __name__ == "__main__":
False,
192 ** (-0.5)
)
print(attn_ref.shape)
torch.testing.assert_close(attn_output, attn_ref, rtol=1e-3, atol=1e-3)
print("test past")
\ No newline at end of file
......@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
for key in keys:
key = ".".join(key.split(".")[:-1])
if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
weight_type = weight.dtype
e_score_correction_bias_type = e_score_correction_bias.dtype
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
tensors = self.load_multi(key, targets, device=device)
weight = tensors[".ffn_gate_inp.weight"]
......@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device)
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
def unload(self):
if self.weight is not None:
......
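The nn.Parameter re-wrap in the gate-loading change above matters because out-of-place ops on a Parameter return a plain Tensor, and nn.Module refuses to assign a plain Tensor to a registered parameter slot. A minimal reproduction:

import torch
import torch.nn as nn

lin = nn.Linear(2, 2)
moved = lin.weight.half()               # out-of-place op returns a plain Tensor
print(isinstance(moved, nn.Parameter))  # False
lin.weight = nn.Parameter(moved)        # fine: re-wrapped, as the change above does
# lin.weight = moved                    # would raise TypeError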
......@@ -26,6 +26,7 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
......@@ -78,7 +79,13 @@ class KLinearBase(ABC):
keys = [self.key]
for key in keys:
if key + ".weight" in self.gguf_loader.tensor_file_map:
if self.gguf_loader.safetensor_loader is not None:
# using safetensor_loader
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
tensors = self.load_multi(key, ["weight", "bias"], device=device)
tensor = tensors["weight"]
......@@ -169,7 +176,61 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearFP8(KLinearBase):
# this kernel requires special handling for weight
# Please load the weight file downloaded from KVCache.AI
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
weight: torch.Tensor
scale_w: torch.Tensor
bias: torch.Tensor
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
block_size: int = 128,
**kwargs,
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.has_bias = False
self.dtype = torch.get_default_dtype()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
y = fp8_gemm(x_quantized, scale_x, self.weight, self.weight_scale_inv)
return y.to(dtype=orig_dtype)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None:
w = self.load_weight(device=device)
        ### TODO: fit weight_scale_inv format
if isinstance(w, tuple):
self.weight = w[0].to(device)
self.weight_scale_inv = w[1].to(device)
self.has_bias = False
else:
raise ValueError("Invalid weight type")
self.weight = self.weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
def unload(self):
if self.weight is not None:
self.weight = None
if self.has_bias:
self.bias = None
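KLinearFP8 quantizes activations per 128-wide block and runs a blockwise-scaled FP8 GEMM. A simplified, hedged sketch of the quantization step (the real act_quant is a Triton kernel; this pure-torch version is illustrative only and assumes the last dim is a multiple of block_size):

import torch

def blockwise_act_quant(x: torch.Tensor, block_size: int = 128):
    blocks = x.reshape(-1, block_size)
    # absmax scale per block; 448 is the max normal value of float8_e4m3fn
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    x_q = (blocks / scale).to(torch.float8_e4m3fn)
    return x_q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)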
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
......@@ -404,7 +465,8 @@ class KLinearCPUInfer(KLinearBase):
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer
"KLinearCPUInfer": KLinearCPUInfer,
"KLinearFP8": KLinearFP8,
}
class KTransformersLinear(BaseInjectedModule, KLinearBase):
......@@ -440,10 +502,11 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
def forward(self, x):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
return self.prefill_linear.forward(x)
y = self.prefill_linear.forward(x)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
return self.generate_linear.forward(x)
y = self.generate_linear.forward(x)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
if not mode:
......
......@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.utils import InferenceState, get_compute_capability
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
......@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
else:
causal_mask = None
# embed positions
hidden_states = inputs_embeds
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
    per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
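The kv_b_proj exclusion in the name pattern above relies on a negative lookahead; a quick hedged check of the regex on its own:

import re

pattern = re.compile(r"^model\.layers\.(?!.*self_attn\.kv_b_proj).*$")
print(bool(pattern.match("model.layers.0.mlp.gate_proj")))        # True  -> replaced
print(bool(pattern.match("model.layers.0.self_attn.kv_b_proj")))  # False -> left as-is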
......@@ -293,6 +293,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
absorb_for_prefill: False
# GPU 1: layers 15–29
- match:
......@@ -302,6 +303,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
absorb_for_prefill: False
# GPU 2: layers 30–44
- match:
......@@ -311,6 +313,7 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
absorb_for_prefill: False
# GPU 3: layers 45–60
- match:
......@@ -320,6 +323,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
absorb_for_prefill: False
# === Overall Model Replacement with Transfer Map ===
......
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
    class: ktransformers.operators.gate.KMoEGate # gate module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
    absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
    absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
    per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......@@ -168,5 +168,5 @@
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_device: "cuda:1"
prefill_device: "cuda:1"
......@@ -60,6 +60,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
    absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
......
......@@ -53,6 +53,17 @@
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
# to use more VRAM, use Marlin experts and disable CUDA Graph (disabling CUDA Graph may reduce performance)
#- match:
# name: "^model\\.layers\\..*\\.mlp\\.experts$"
# replace:
#    class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert parallelism
# kwargs:
# prefill_device: "cuda"
# prefill_op: "KExpertsTorch"
# generate_device: "cuda"
# generate_op: "KExpertsMarlin"
# recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
......
......@@ -12,8 +12,8 @@ from ktransformers.server.config.config import Config
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import check_link_response
from ktransformers.server.backend.base import BackendInterfaceBase
router = APIRouter(prefix='/api')
router = APIRouter(prefix='/api')
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-completion
class OllamaGenerateCompletionRequest(BaseModel):
......@@ -40,61 +40,121 @@ class OllamaGenerateCompletionRequest(BaseModel):
keep_alive: Optional[str] = Field(
"5m", description="Controls how long the model will stay loaded into memory following the request.")
class OllamaGenerationStreamResponse(BaseModel):
model: str
created_at: str
response: str
done: bool = Field(...)
class OllamaGenerationResponse(BaseModel):
pass
@router.post("/generate", tags=['ollama'])
async def generate(request: Request, input: OllamaGenerateCompletionRequest):
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
print(f'COMPLETION INPUT:----\n{input.prompt}\n----')
config = Config()
if input.stream:
async def inner():
async for token in interface.inference(input.prompt,id):
d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response=token,done=False)
yield d.model_dump_json()+'\n'
# d = {'model':config.model_name,'created_at':"", 'response':token,'done':False}
# yield f"{json.dumps(d)}\n"
# d = {'model':config.model_name,'created_at':"", 'response':'','done':True}
# yield f"{json.dumps(d)}\n"
d = OllamaGenerationStreamResponse(model=config.model_name,created_at=str(datetime.now()),response='',done=True)
yield d.model_dump_json()+'\n'
return check_link_response(request,inner())
async for token in interface.inference(input.prompt, id):
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response=token,
done=False
)
yield d.model_dump_json() + '\n'
d = OllamaGenerationStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
response='',
done=True
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError
# https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion
class OllamaChatCompletionMessage(BaseModel):
role: str
content: str
class OllamaChatCompletionRequest(BaseModel):
pass
model: str = Field(..., description="The model name, which is required.")
messages: List[OllamaChatCompletionMessage] = Field(
..., description="A list of messages to generate a response for.")
stream: bool = Field(True, description="If true, the response will be streamed.")
class OllamaChatCompletionStreamResponse(BaseModel):
pass
model: str
created_at: str
message: dict
done: bool = Field(...)
total_duration: Optional[int] = Field(None, description="Total time spent in nanoseconds")
load_duration: Optional[int] = Field(None, description="Time spent loading model in nanoseconds")
prompt_eval_count: Optional[int] = Field(None, description="Number of tokens in prompt")
prompt_eval_duration: Optional[int] = Field(None, description="Time spent evaluating prompt in nanoseconds")
eval_count: Optional[int] = Field(None, description="Number of tokens generated")
eval_duration: Optional[int] = Field(None, description="Time spent generating response in nanoseconds")
class OllamaChatCompletionResponse(BaseModel):
pass
@router.post("/chat", tags=['ollama'])
async def chat(request: Request, input: OllamaChatCompletionRequest):
raise NotImplementedError
id = str(uuid4())
interface: BackendInterfaceBase = get_interface()
config = Config()
    # convert the chat messages into a single prompt string
prompt = ""
for msg in input.messages:
prompt += f"{msg.role}: {msg.content}\n"
prompt += "assistant:"
if input.stream:
async def inner():
            start_time = time()  # record the start time in seconds
            eval_count = 0  # number of generated tokens
            tokens = []
            async for token in interface.inference(prompt, id):
                eval_count += 1
                tokens.append(token)
                d = OllamaChatCompletionStreamResponse(
                    model=config.model_name,
                    created_at=str(datetime.now()),
                    message={"role": "assistant", "content": token},
                    done=False
                )
                yield d.model_dump_json() + '\n'
            # compute performance statistics
            end_time = time()
            total_duration = int((end_time - start_time) * 1_000_000_000)  # convert to nanoseconds
            prompt_eval_count = len(prompt.split())  # rough estimate of the prompt token count
            eval_duration = total_duration  # assume all time was spent generating (simplified)
            prompt_eval_duration = 0  # assume no separate prompt-evaluation time
            load_duration = 0  # load duration unknown, report 0
d = OllamaChatCompletionStreamResponse(
model=config.model_name,
created_at=str(datetime.now()),
message={},
done=True,
total_duration=total_duration,
load_duration=load_duration,
prompt_eval_count=prompt_eval_count,
prompt_eval_duration=prompt_eval_duration,
eval_count=eval_count,
eval_duration=eval_duration
)
yield d.model_dump_json() + '\n'
return check_link_response(request, inner())
else:
raise NotImplementedError("Non-streaming chat is not implemented.")
# https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models
class OllamaModel(BaseModel):
......@@ -103,9 +163,8 @@ class OllamaModel(BaseModel):
size: int
# TODO: fill the rest correctly
# mock ollama
@router.get("/tags",tags=['ollama'])
@router.get("/tags", tags=['ollama'])
async def tags():
config = Config()
    # TODO: fill this correctly, although it does not affect Tabby
......@@ -138,25 +197,21 @@ class OllamaShowResponse(BaseModel):
class Config:
protected_namespaces = ()
@router.post("/show", tags=['ollama'])
async def show(request: Request, input: OllamaShowRequest):
config = Config()
    # TODO: Add more info in config to return, although it does not affect Tabby
return OllamaShowResponse(
modelfile = "# Modelfile generated by ...",
parameters = " ",
template = " ",
details = OllamaShowDetial(
parent_model = " ",
format = "gguf",
family = " ",
families = [
" "
],
parameter_size = " ",
quantization_level = " "
modelfile="# Modelfile generated by ...",
parameters=" ",
template=" ",
details=OllamaShowDetial(
parent_model=" ",
format="gguf",
family=" ",
families=[" "],
parameter_size=" ",
quantization_level=" "
),
model_info = OllamaModelInfo()
model_info=OllamaModelInfo()
)
\ No newline at end of file
......@@ -14,6 +14,7 @@ from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled, MLAWrapperSingleton
warm_uped = False
......@@ -35,9 +36,9 @@ class KTransformersInterface(TransformersInterface):
with torch.device("meta"):
self.model = custom_models[config.architectures[0]](config)
if default_args.optimize_config_path is None:
optimize_rule_path = default_optimize_rules[config.architectures[0]]
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_rule_path = args.optimize_config_path
optimize_config_path = args.optimize_config_path
# print(optimize_config)
......@@ -47,7 +48,7 @@ class KTransformersInterface(TransformersInterface):
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
......@@ -186,6 +187,8 @@ class KTransformersInterface(TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
torch.cuda.set_device(device)
if flashinfer_enabled:
MLAWrapperSingleton.need_plan_all()
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
......@@ -198,6 +201,8 @@ class KTransformersInterface(TransformersInterface):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
if flashinfer_enabled:
MLAWrapperSingleton.reset_buffer()
self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
......
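For context, the calls added above follow a simple per-request lifecycle around the flashinfer MLA wrapper; a hedged outline (method names as defined in flashinfer_wrapper above):

# 1) before prefill: MLAWrapperSingleton.need_plan_all()  # invalidate plans so new lengths are used
# 2) forward pass:   each attention layer re-plans once, then reuses its plan
# 3) after prefill:  MLAWrapperSingleton.reset_buffer()   # restore decode-shaped qo_indptr (q_len == 1)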
......@@ -333,14 +333,14 @@ class TransformersInterface(BackendInterfaceBase):
for i in range(1, self.args.max_new_tokens):
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
if i > 1 and flashinfer_enabled:
if flashinfer_enabled:
MLAWrapperSingleton.plan_all(None,None,None,self.active_cache_position.to(torch.int32)+1,
num_heads=self.model.config.num_attention_heads, head_dim_ckv=self.model.config.kv_lora_rank,
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size,
sm_scale=(self.model.config.qk_rope_head_dim + self.model.config.qk_nope_head_dim) ** (-0.5), q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
next_token = self.decode_one_tokens()
self.profiler.inc("decode")
if next_token == self.tokenizer.eos_token_id:
if next_token == self.tokenizer.eos_token_id or "<|im_end|>" == self.tokenizer.decode(next_token):
assert self.args.batch_size == 1
break
yield self.append_new_tokens(next_token)
......
......@@ -173,8 +173,8 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="API Generate Tester")
parser.add_argument("--concurrent", type=int, default=1000, help="Number of concurrent evaluations")
parser.add_argument("--file", type=str, default="TIGER-Lab/MMLU-Pro", help="Path to the mmlu.jsonl file")
parser.add_argument("--result", type=str, default="./mmlu_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_pro.log", help="Path to save the log file")
parser.add_argument("--result", type=str, default="./mmlu_result_pro.json", help="Path to save the result JSON file")
parser.add_argument("--log", type=str, default="./mmlu_result_pro.log", help="Path to save the log file")
parser.add_argument("--model", type=str, default="Pro/deepseek-ai/DeepSeek-V3", help="Model name or path")
parser.add_argument("--api_url", type=str, default="http://localhost:15488/v1/chat/completions", help="API URL")
# parser.add_argument("--api_url", type=str, default="https://api.siliconflow.cn/v1/chat/completions", help="API URL")
......
import torch
import torch.nn.functional as F
from typing import Optional
import pytest
from typing import Tuple, Optional, Literal
import time
# add the repo root to sys.path
import os
import sys
sys.path.insert(0, "/home/azure/ktransformers")
print(sys.path)
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors import safe_open
world_size = 1
rank = 0
block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
# Assuming `fp8_gemm`, `act_quant`, `weight_dequant` and other relevant functions are already defined
def test_fp8_gemm_vs_torch_matmul():
# Test case 1: Create random matrices of size (M, K) and (K, N)
M, K, N = 64, 128, 256 # Matrix dimensions
x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
weight = torch.randn(N, K, dtype=torch.bfloat16, device='cuda')
# Apply act_quant to both matrices
x_quantized, scale_x = act_quant(x, block_size)
weight_quantized, scale_w = act_quant(weight, block_size)
    # make the quantized tensors contiguous
x_quantized = x_quantized.contiguous()
weight_quantized = weight_quantized.contiguous()
scale_x = scale_x.contiguous()
scale_w = scale_w.contiguous()
# Perform fp8_gemm using the quantized tensors
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight_quantized, scale_w)
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight.T)
print(f'result_torch_matmul: {result_torch_matmul.shape}')
print(f'result_fp8_gemm: {result_fp8_gemm.shape}')
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
print(f"result_torch_matmul:\n {result_torch_matmul}")
def test_fp8_gemm_vs_torch_matmul_load():
file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
with safe_open(file_path, framework="pt", device=0) as f:
weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
# weight_dequant
weight_dequantized = weight_dequant(weight, scale)
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 64
x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')
x_quantized, scale_x = act_quant(x, block_size)
    # Test case 1: quantized x matmul with the still-quantized weight
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
print(f"result_fp8_gemm:\n {result_fp8_gemm}")
print(f"dtype {result_fp8_gemm.dtype}")
# Perform torch.matmul using the original floating point tensors
result_torch_matmul = torch.matmul(x, weight_dequantized.to(torch.bfloat16).T)
print(f"result_torch_matmul:\n {result_torch_matmul}")
def test_fp8_gemm_tflops():
file_path = "/mnt/data/model/DeepSeek-V3/model-00001-of-000163.safetensors"
with safe_open(file_path, framework="pt", device=0) as f:
weight = f.get_tensor("model.layers.0.mlp.down_proj.weight")
scale = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale_inv")
# weight_dequant
weight_dequantized = weight_dequant(weight, scale)
print(f"weight_dequantized: {weight_dequantized.shape}")
N, K = weight_dequantized.shape
M = 6400
x = torch.randn(2 ,M, K, dtype=torch.bfloat16, device='cuda')
# x_quantized, scale_x = act_quant(x, block_size)
    # time i fp8_gemm iterations (after the two warm-up runs below)
    i = 10
flops_per_gemm = 2 * M * N * K
total_flops = i * flops_per_gemm
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
    torch.cuda.synchronize() # finish warm-up work before starting the timer
    t0 = time.time()
    for _ in range(i):
x_quantized, scale_x = act_quant(x, block_size)
result_fp8_gemm = fp8_gemm(x_quantized, scale_x, weight, scale)
torch.cuda.synchronize()
t1 = time.time()
total_time = t1 - t0
tflops = total_flops / total_time / 1e12
print(f"total_time: {total_time}")
print(f"tflops: {tflops}")
if __name__ == "__main__":
test_fp8_gemm_vs_torch_matmul()
test_fp8_gemm_vs_torch_matmul_load()
    test_fp8_gemm_tflops()
\ No newline at end of file
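As a sanity check on the throughput formula above, with illustrative numbers (nothing here is measured):

M, N, K, iters = 6400, 18432, 7168, 10  # shapes are assumptions, not read from the model
total_flops = iters * 2 * M * N * K     # 2*M*N*K FLOPs per GEMM
total_time = 0.05                       # seconds, made up
print(f"tflops: {total_flops / total_time / 1e12:.1f}")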
......@@ -25,6 +25,7 @@ import os
from enum import IntEnum
import torch
import KTransformersOps
from .custom_loader import SafeTensorLoader
import ctypes
class GGMLQuantizationType(IntEnum):
......@@ -128,6 +129,7 @@ GGML_BLOCK_SIZES = {
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
"IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
"FP8": 1,
}
GGML_ELEMENTS_PER_BLOCK = {
......@@ -143,6 +145,7 @@ GGML_ELEMENTS_PER_BLOCK = {
"Q5_K": 256,
"Q6_K": 256,
"IQ4_XS": 256,
"FP8": 1,
}
DATA_TYPES = {
......@@ -159,6 +162,7 @@ DATA_TYPES = {
"uint64": 10,
"int64": 11,
"float64": 12,
"FP8": 13,
}
class GGUFLoader:
......@@ -166,12 +170,15 @@ class GGUFLoader:
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
safetensor_loader: SafeTensorLoader
def __init__(self, gguf_path: str):
        # Check that the dir exists
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
if os.path.isfile(gguf_path):
gguf_path = os.path.dirname(gguf_path)
self.safetensor_loader = None
self.tensor_info = {}
self.gguf_path = gguf_path
......@@ -179,7 +186,13 @@ class GGUFLoader:
self.file_data_map = {}
self.gguf_file_meta = {}
self.tensor_device_map = {}
# I know this is ugly, but I don't want to change the original code too much
# TODO: merge gguf load and other loads.
safetensor_loader = SafeTensorLoader(gguf_path)
if safetensor_loader.tensor_file_map:
self.safetensor_loader = safetensor_loader
return
# Walk through all the .gguf files in the directory
found_gguf = False
for root, dirs, files in os.walk(gguf_path):
......@@ -286,6 +299,13 @@ class GGUFLoader:
itemsize = int(np.empty([], dtype = item_type).itemsize)
return mmap_data[offset : offset + itemsize * item_count]
def get_undequanted_tensor_and_ggml_type(self, name):
t = self.tensor_info[name]
data = self.get_mmap_tensor(name)
ggml_type = t["ggml_type"]
data = torch.from_numpy(data)
return data, ggml_type
def load_expert_tensor(self, name, data, expert_id, elements_per_expert, device = "cuda", target_dtype = torch.get_default_dtype())->torch.Tensor:
t = self.tensor_info[name]
if device.lower() == "cpu":
......@@ -310,6 +330,8 @@ class GGUFLoader:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values.copy())
if ggml_name == "BF16":
values = values.view(torch.bfloat16)
values = values.view(shape[-2::-1])
return values
......@@ -418,6 +440,9 @@ def read_value(f, data_type):
elem_type, count = struct.unpack("<IQ", f.read(4 + 8))
return [read_value(f, elem_type) for _ in range(count)]
elif data_type == DATA_TYPES["FP8"]:
return struct.unpack("<B", f.read(1))[0]
else:
raise NotImplementedError(f"Data type {data_type} not implemented")
......
import struct
import warnings
import numpy as np
import re
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
from safetensors import safe_open
from ktransformers.ktransformers_ext.triton.fp8gemm import fp8_gemm, act_quant, weight_dequant
from safetensors.torch import save_file
class SafeTensorLoader:
tensor_file_map = {}
tensor_type_map = {}
file_handle_map = {}
def __init__(self, file_path: str):
self.__load_tensor_file_map(file_path)
def __load_tensor_file_map(self, file_path: str):
        # normalize the given path: make sure we end up with a folder path
if not os.path.exists(file_path):
raise FileNotFoundError(f"Path not found: {file_path}")
if os.path.isfile(file_path):
folder_path = os.path.dirname(file_path)
else:
folder_path = file_path
found_safetensor = False
for root, _, files in os.walk(folder_path):
files = sorted(files)
for file in files:
if file.endswith(".safetensors"):
found_safetensor = True
file_path = os.path.join(root, file)
if file not in self.file_handle_map:
try:
handle = safe_open(file_path, framework="pt")
self.file_handle_map[file] = handle
except Exception as e:
print(f"Error opening Safetensor file {file_path}: {e}")
continue
f = self.file_handle_map.get(file)
if f is None:
continue
try:
for key in f.keys():
self.tensor_file_map[key] = file
except Exception as e:
print(f"Error reading Safetensor file {file_path}: {e}")
# if not found_safetensor:
# raise FileNotFoundError(f"No Safetensor files found in {folder_path}")
def load_tensor(self, key: str, device: str="cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key)
return tensor.to(device)
def close_all_handles(self):
for handle in self.file_handle_map.values():
handle.close()
self.file_handle_map.clear()
def load_dequantized_tensor(self, key:str, device: str="cpu"):
if key not in self.tensor_file_map:
raise KeyError(f"Key {key} not found in Safetensor files")
file = self.tensor_file_map[key]
f = self.file_handle_map.get(file)
if f is None:
raise FileNotFoundError(f"File {file} not found in Safetensor files")
tensor = f.get_tensor(key).to(device)
if key.endswith(".weight"):
if key[:-7] + ".weight_scale_inv" in self.tensor_file_map:
weight_scale_inv = f.get_tensor(key[:-7] + ".weight_scale_inv").to(device)
tensor = weight_dequant(tensor, weight_scale_inv)
return tensor.to(device)
\ No newline at end of file
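A hedged usage sketch for the loader above (the model path and tensor name are illustrative assumptions):

loader = SafeTensorLoader("/mnt/data/model/DeepSeek-V3")
if loader.tensor_file_map:  # empty if no .safetensors files were found
    w = loader.load_dequantized_tensor("model.layers.0.mlp.down_proj.weight", device="cuda")
    print(w.shape, w.dtype)
loader.close_all_handles()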
......@@ -21,6 +21,18 @@ from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
warm_uped = False
def get_compute_capability(device:torch.device = None):
    if torch.cuda.is_available():
        if device is None:
            num_gpus = torch.cuda.device_count()
            min_compute_capability_major = 100
            for gpu_id in range(num_gpus):
                gpu_props = torch.cuda.get_device_properties(gpu_id)
                min_compute_capability_major = min(min_compute_capability_major, gpu_props.major)
            return min_compute_capability_major
        else:
            return torch.cuda.get_device_properties(device).major
    return 0 # no CUDA device available; callers fall back to the non-flash path
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
......@@ -66,13 +78,22 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
if translated_key in gguf_loader.tensor_file_map:
        # TODO: merge all loaders.
        # I know this is ugly but let's do it for now.
if gguf_loader.safetensor_loader is not None:
load_dequantized_tensor = gguf_loader.safetensor_loader.load_dequantized_tensor
tensor_file_map = gguf_loader.safetensor_loader.tensor_file_map
else:
load_dequantized_tensor = gguf_loader.load_gguf_tensor
tensor_file_map = gguf_loader.tensor_file_map
if translated_key in tensor_file_map:
target_dtype = torch.get_default_dtype()
device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
print(f"loading {translated_key} to {device}")
torch.cuda.empty_cache()
# device = "cpu" if "embd" in translated_key else "cuda"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
set_param(module, name, weights)
del weights
else:
......@@ -154,6 +175,10 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
inputs_embeds = model.model.embed_tokens(inputs.to("cpu"))
else:
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to(torch_device)
if use_flashinfer_mla:
MLAWrapperSingleton.update_buffer(past_key_values.max_pages)
MLAWrapperSingleton.need_plan_all()
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone().to(torch_device)
......@@ -176,6 +201,9 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
else:
next_token = torch.argmax(next_token_scores, dim=-1)
first_token_time = time.time() - start_time
if use_flashinfer_mla:
MLAWrapperSingleton.reset_buffer()
prefill_count = seq_length
prefill_time = first_token_time
......@@ -193,15 +221,15 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
start_time = time.time()
for i in range(1, max_new_tokens):
if use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
global warm_uped
if use_cuda_graph and ( (warm_uped == True and int(i) == 1) or (warm_uped == False and int(i) == 2) ):
warm_uped = True
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, torch_device, return_dict=False, use_cache=True)
if i > 1 and use_flashinfer_mla:
MLAWrapperSingleton.plan_all(None,None,None,position_ids.squeeze(1)+1,
num_heads, head_dim_ckv, head_dim_kpe, past_key_values.page_size,
q_head_dim ** (-0.5), torch.bfloat16, torch.bfloat16)
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, use_cuda_graph).to(torch_device)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
......