Commit c009512a authored by Azure-Tang

Merge branch 'main' into hip

parents c1f13a69 4f22d726
......@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits[:,-1,:].unsqueeze(0).float()
logits = self.lm_head(hidden_states[:,-1:,:]).float()
loss = None
if labels is not None:
......
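The first two hunks make the same memory fix in both DeepseekV2 and DeepseekV3: the last-position slice is taken before the lm_head projection rather than after, so a [batch, seq_len, vocab]-sized logits tensor is never materialized during generation. A minimal sketch (illustrative shapes, not the repo's code) of the equivalence:

```python
import torch

# Illustrative sizes; real vocabularies are much larger, which is the point.
batch, seq_len, hidden, vocab = 1, 4096, 1024, 32000
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

old = lm_head(hidden_states)[:, -1:, :].float()   # projects every position
new = lm_head(hidden_states[:, -1:, :]).float()   # projects only the last one
assert torch.allclose(old, new, atol=1e-5)
```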
......@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.device))
logits = self.lm_head(hidden_states[:,-1:,:])
logits = logits.float()
loss = None
......
......@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim, orig_module.max_position_embeddings, orig_module.base
......@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
......@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim,
......@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim,
......@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
# **kwargs,
# ):
# BaseInjectedModule.__init__(
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
# self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
# )
# self.generate_device = generate_device
# self.prefill_device = prefill_device
......@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
......@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, device, **kwargs
self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
)
self.orig_module.__init__(
orig_module.dim,
......
......@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
nn.Module.__init__(self)
nn.Module.__setattr__(self, "orig_module", orig_module)
object.__setattr__(self, "key", key)
object.__setattr__(self, "gguf_loader", gguf_loader)
object.__setattr__(self, "config", config)
object.__setattr__(self, "device", device)
object.__setattr__(self, "prefill_device", prefill_device)
object.__setattr__(self, "generate_device", generate_device)
object.__setattr__(self, "device", generate_device)
def __getattr__(self, name: str) -> Any:
# __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
......
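Every rotary-embedding hunk above is the same mechanical change as this one: BaseInjectedModule.__init__ now takes prefill_device ahead of generate_device, and every call site passes both. The legacy device attribute is kept but now aliases generate_device. A reduced sketch of the new contract (types and bookkeeping simplified; the real class also wires up gguf_loader, config, and attribute forwarding):

```python
import torch.nn as nn

class BaseInjectedModule(nn.Module):
    def __init__(self, key, gguf_loader, config, orig_module,
                 prefill_device: str = "cuda", generate_device: str = "cuda",
                 **kwargs):
        super().__init__()
        object.__setattr__(self, "prefill_device", prefill_device)
        object.__setattr__(self, "generate_device", generate_device)
        # Legacy `device` attribute now aliases the generate-time device.
        object.__setattr__(self, "device", generate_device)
```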
......@@ -18,6 +18,7 @@ import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule
from tqdm import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
......@@ -118,6 +119,7 @@ class KExpertsCPU(KExpertsBase):
output_cpu:Tensor = None
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
#gguf_loader:GGUFLoader = None
CPU_INFER = CPUInfer(Config().cpu_infer)
def __init__(
self,
......@@ -131,6 +133,9 @@ class KExpertsCPU(KExpertsBase):
**kwargs
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
#if KExpertsCPU.gguf_loader is None:
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
self.gguf_loader = gguf_loader
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
self.n_routed_experts = n_routed_experts
self.out_device = out_device
......@@ -154,7 +159,7 @@ class KExpertsCPU(KExpertsBase):
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
#print(self.gate_type, self.up_type, self.down_type)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
......@@ -225,6 +230,7 @@ class KExpertsCPU(KExpertsBase):
return
def load_weights(self, override_key: str | None = None, device: str = "cpu"):
# TODO: support Bias
res = {}
if override_key is not None:
keys = override_key
......@@ -239,7 +245,16 @@ class KExpertsCPU(KExpertsBase):
down_type = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
# temporary (and admittedly ugly) way to load the tensor eagerly
gate = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.weight").numpy()
up = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.weight").numpy()
down = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.weight").numpy()
gate_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_exps.ggml_type").item()
up_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_up_exps.ggml_type").item()
down_type = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_down_exps.ggml_type").item()
elif key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
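KExpertsCPU.load_weights now branches on loader type: a safetensor-backed loader, when present, serves the raw expert tensors plus their ggml quantization types as companion tensors, with GGUF mmap as the fallback. A hedged sketch of the dispatch; the loader method names come from the diff, the helper itself is illustrative:

```python
def load_expert_weight(loader, key: str, part: str = "ffn_gate_exps"):
    """Illustrative helper: safetensors first, GGUF mmap as fallback."""
    name = f"{key}.{part}.weight"
    if loader.safetensor_loader is not None:
        weight = loader.safetensor_loader.load_tensor(name).numpy()
        ggml_type = loader.safetensor_loader.load_tensor(f"{key}.{part}.ggml_type").item()
    elif name in loader.tensor_info:
        weight = loader.get_mmap_tensor(name)
        ggml_type = loader.tensor_info[name]["ggml_type"]
    else:
        raise KeyError(name)
    return weight, ggml_type
```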
......@@ -288,6 +303,8 @@ class KExpertsMarlin(KExpertsBase):
self.act_fn = ACT2FN[config.hidden_act]
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
self.device = device
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
# create empty marlin experts according to the number of experts per token
# up
self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
......@@ -299,17 +316,34 @@ class KExpertsMarlin(KExpertsBase):
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
if w is None: w = self.load_weights()[self.key]
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
if w is None:
w = self.load_weights()
load_by_experts = True
if load_by_experts:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in tqdm(range(self.expert_num), desc=f"Dequantizing and requantizing KExpertsMarlin {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)
self.up_projs[i].load(nn.Parameter(up_weights), device=device)
self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
self.down_projs[i].load(nn.Parameter(down_weights), device=device)
self.loaded_experts_idx.append(i)
else:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
return
def unload(self):
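The new load path above streams one expert at a time: each expert's slice is cut out of the raw mmap buffer, dequantized, and immediately requantized into its Marlin projection, so peak memory stays near one expert's worth of weights instead of the full tensor. The shape of the loop, assuming load_expert_tensor returns expert i's dequantized weight (class plumbing omitted):

```python
from tqdm import tqdm
import torch.nn as nn

def load_streaming(self, raw):  # raw = {"gate": ..., "up": ..., "down": ...}
    for i in tqdm(range(self.expert_num), desc=f"Quantizing experts for {self.key}"):
        for part, projs in (("gate", self.gate_projs),
                            ("up", self.up_projs),
                            ("down", self.down_projs)):
            w = self.gguf_loader.load_expert_tensor(
                f"{self.key}.ffn_{part}_exps.weight", raw[part], i,
                self.elements_per_tensor, device=self.device)
            projs[i].load(nn.Parameter(w), device=self.device)
        self.loaded_experts_idx.append(i)
```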
......@@ -329,20 +363,13 @@ class KExpertsMarlin(KExpertsBase):
gate = None
up = None
down = None
gate_type = None
up_type = None
down_type = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
......@@ -381,6 +408,7 @@ class KExpertsMarlin(KExpertsBase):
return final_hidden_states.to(dtype=org_dtype, device=org_device)
# untested, CUDA OOM
class KExpertsTorch(KExpertsBase):
expert_num: int
loaded_experts_idx: list[int]
......@@ -402,19 +430,39 @@ class KExpertsTorch(KExpertsBase):
# self.loaded_experts_idx = []
self.act_fn = ACT2FN[config.hidden_act]
self.device = device
self.gate = None
self.up = None
self.donw = None
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
self.gate = [None for _ in range(self.expert_num)]
self.up = [None for _ in range(self.expert_num)]
self.down = [None for _ in range(self.expert_num)]
self.dtype = torch.get_default_dtype()
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)[self.key]
if isinstance(w, dict):
self.gate = w["gate"].to(device=device, dtype=self.dtype)
self.up = w["up"].to(device=device, dtype=self.dtype)
self.down = w["down"].to(device=device, dtype=self.dtype)
if w is None:
w = self.load_weights()
load_by_experts = True
if load_by_experts:
if isinstance(w, dict):
for i in tqdm(range(self.expert_num), desc=f"Dequantizing KExpertsTorch {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)
self.up[i] = up_weights
self.gate[i] = gate_weights
self.down[i] = down_weights
else:
if isinstance(w, dict):
for i in range(self.expert_num):
self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)
self.up = torch.stack(self.up, dim=0)
self.gate = torch.stack(self.gate, dim=0)
self.down = torch.stack(self.down, dim=0)
return
def unload(self):
if self.gate is not None:
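KExpertsTorch gets the same streaming treatment but keeps plain tensors: each expert is dequantized into a per-expert list, and the lists are then stacked into single [num_experts, out, in] weights so the forward pass can index experts by id. The stacking step in isolation:

```python
import torch

num_experts, out_dim, in_dim = 8, 16, 32
per_expert = [torch.randn(out_dim, in_dim) for _ in range(num_experts)]
stacked = torch.stack(per_expert, dim=0)  # one batched [8, 16, 32] weight
assert stacked.shape == (num_experts, out_dim, in_dim)
```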
......@@ -422,6 +470,25 @@ class KExpertsTorch(KExpertsBase):
self.up = None
self.down = None
def load_weights(self, override_key: str | None = None):
res = {}
if override_key is not None:
keys = override_key
else:
keys = [self.key]
gate = None
up = None
down = None
for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res
def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
org_device = hidden_states_cpu.device
......@@ -478,7 +545,7 @@ class KTransformersExperts(BaseInjectedModule, KExpertsBase):
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
......@@ -582,7 +649,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
......@@ -601,8 +668,7 @@ class KQwen2MoeSparseMoeBlock(BaseInjectedModule, Qwen2MoeSparseMoeBlock):
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
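moe_on_cpuinfer is renamed to moe_kexperts because the dispatch works for any KExpertsBase backend, not only the CPU-infer one, and the `outs = torch.empty_like(x)` pre-allocation is dropped as dead code: self.experts(...) returns a fresh tensor that immediately rebinds outs. The identical rename repeats below for the DeepseekV2, DeepseekV3, and Mixtral MoE blocks. The resulting helper is just:

```python
import torch

@torch.no_grad()
def moe_kexperts(self, x, topk_ids, topk_weight):
    # Delegate to whichever KExpertsBase backend was injected.
    return self.experts(x, topk_ids, topk_weight)
```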
......@@ -672,7 +738,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO: may be bugs here
y = (
......@@ -692,8 +758,7 @@ class KDeepseekV2MoE(BaseInjectedModule, DeepseekV2MoE):
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
......@@ -773,7 +838,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y_ = self.shared_experts(identity).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO: may be bugs here
y = (
......@@ -793,8 +858,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
return y
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
......@@ -881,7 +945,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
......@@ -900,8 +964,7 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
return y, router_logits
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs
......
......@@ -67,7 +67,14 @@ class KMoEGateBase(ABC):
for key in keys:
key = ".".join(key.split(".")[:-1])
if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
if self.gguf_loader.safetensor_loader is not None:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
weight = self.gguf_loader.safetensor_loader.load_tensor(key + ".ffn_gate_inp.weight")
e_score_correction_bias = self.gguf_loader.safetensor_loader.load_tensor(key + ".exp_probs_b.bias")
weight_type = weight.dtype
e_score_correction_bias_type = e_score_correction_bias.dtype
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
elif key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
tensors = self.load_multi(key, targets, device=device)
weight = tensors[".ffn_gate_inp.weight"]
......@@ -93,11 +100,11 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
......@@ -116,8 +123,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device)
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
self.orig_module.weight = nn.Parameter(self.orig_module.weight.to(device))
self.orig_module.e_score_correction_bias = nn.Parameter(self.orig_module.e_score_correction_bias.to(device))
def unload(self):
if self.weight is not None:
......
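The nn.Parameter re-wrap in KMoEGate.load matters because Tensor.to returns a plain tensor whenever it has to copy, and nn.Module will not accept a plain tensor in a slot registered as a parameter. A small demonstration of the failure mode the change avoids:

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2)
moved = m.weight.to(torch.float64)      # copy -> plain Tensor, not Parameter
assert not isinstance(moved, nn.Parameter)

m.weight = nn.Parameter(moved)          # re-wrapping keeps it registered
assert "weight" in dict(m.named_parameters())
# Assigning `moved` directly instead would raise:
# TypeError: cannot assign 'torch.DoubleTensor' as parameter 'weight'
```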
......@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
from ktransformers.models.configuration_llama import LlamaConfig
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
from ktransformers.util.utils import InferenceState, get_compute_capability
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
from ktransformers.models.modeling_llama import (
......@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
if per_layer_prefill_flag:
causal_mask = None
else:
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
if os.name == 'nt' or get_compute_capability()<8:
print("for Windows or GPU before ampere, use forward_windows")
# only use mask in forward windows or can't flash attn
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
else:
causal_mask = None
# embed positions
hidden_states = inputs_embeds
......
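The prefill path now builds a causal mask only when flash attention cannot be used: on Windows, or on GPUs older than Ampere (compute capability below 8). A sketch of the gate, assuming the imported get_compute_capability helper wraps torch.cuda.get_device_capability:

```python
import os
import torch

def get_compute_capability(device=None) -> int:
    # Assumed behavior of the imported helper: major CC of the CUDA device.
    major, _minor = torch.cuda.get_device_capability(device)
    return major

# Flash attention needs SM80+ (Ampere) and handles causality internally,
# so the causal mask is only materialized on the fallback path.
needs_mask = os.name == "nt" or get_compute_capability() < 8
```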
......@@ -126,6 +126,8 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
gguf_loader=GGUFLoader(gguf_path)
with torch.device("meta"):
inject(module, optimize_config, model_config, gguf_loader)
# pre-load lm_head first because it produces large intermediate results
load_weights(module.lm_head, gguf_loader, "lm_head.")
load_weights(module, gguf_loader)
module.gguf_loader = gguf_loader
del_meta(module)
......
......@@ -219,8 +219,20 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)|(^lm_head)"
name: "(^model\\.layers\\.([5][0-9]|[4][5-9])\\.)|(^model.norm)"
replace:
class: "default"
kwargs:
......
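This and the remaining config hunks all make the same move: lm_head is pulled out of the catch-all "default" rule and given an explicit KTransformersLinear replacement (KLinearMarlin for generate, KLinearTorch for prefill) on a pinned device, with the catch-all regex narrowed accordingly. A quick check of what the old and new patterns match:

```python
import re

old = r"(^model\.layers\.([5][0-9]|[4][5-9])\.)|(^model.norm)|(^lm_head)"
new = r"(^model\.layers\.([5][0-9]|[4][5-9])\.)|(^model.norm)"

assert re.match(old, "lm_head")           # old catch-all swallowed lm_head
assert not re.match(new, "lm_head")       # new one leaves it to its own rule
assert re.match(new, "model.layers.47.")  # layer routing is unchanged
```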
......@@ -118,7 +118,18 @@
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
......
......@@ -15,6 +15,18 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......
......@@ -118,7 +118,18 @@
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)"
replace:
class: "default"
kwargs:
......
......@@ -15,6 +15,18 @@
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
class: torch.nn.Linear
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
......