Unverified Commit 7527619f authored by UnicornChan's avatar UnicornChan Committed by GitHub
Browse files

Merge pull request #122 from kvcache-ai/feat-DeepSeekV3

[Feat] add support to DeepSeekV3
parents f4903d54 6f0fe953
from typing import Any, Union
import numpy as np
import numpy.typing as npt
from torch import Tensor, nn
import torch.nn.functional as F
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from cpuinfer_ext.moe import MOEConfig, MOE
import ctypes
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import time
# class Base(BaseInjectedModule, ABC):
class KMoEGateBase(ABC):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
device: str = "cuda",
**kwargs):
# super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
super().__init__()
self.key = key
self.gguf_loader = gguf_loader
self.config = config
self.device = device
self.orig_module = orig_module
@abstractmethod
def forward(self, input_tensor, expert_ids, weights):
pass
@abstractmethod
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu", warmup: bool = False):
pass
@abstractmethod
def unload():
pass
def load_weights(self, override_key: str | None = None, device: str = "cpu"):
res = {}
if override_key is not None:
keys = override_key
else:
keys = [self.key]
gate = None
up = None
down = None
gate_type = None
up_type = None
down_type = None
for key in keys:
key = ".".join(key.split(".")[:-1])
if key + ".ffn_gate_inp.weight" in self.gguf_loader.tensor_info:
targets = [".ffn_gate_inp.weight", ".exp_probs_b.bias"]
tensors = self.load_multi(key, targets, device=device)
weight = tensors[".ffn_gate_inp.weight"]
e_score_correction_bias = tensors[".exp_probs_b.bias"]
weight_type = self.gguf_loader.tensor_info[key + ".ffn_gate_inp.weight"]["ggml_type"]
e_score_correction_bias_type = self.gguf_loader.tensor_info[key + ".exp_probs_b.bias"]["ggml_type"]
else:
raise ValueError(f"Experts {key} not found in gguf_loader")
res = {"weight": weight, "e_score_correction_bias": e_score_correction_bias, "weight_type": weight_type, "e_score_correction_bias_type": e_score_correction_bias_type}
return res
def load_multi(self, key: str, keys: list[str], device: str = "cpu"):
tensors = {}
for k in keys:
tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
return tensors
class KMoEGate(BaseInjectedModule, KMoEGateBase):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
self.generate_device = generate_device
self.prefill_device = prefill_device
def forward(self, hidden_states) -> torch.Tensor:
return self.orig_module.forward(hidden_states)
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)
if isinstance(w, dict):
self.weight_type = w["weight_type"]
self.e_score_correction_bias_type = w["e_score_correction_bias_type"]
self.orig_module.weight = nn.Parameter(w["weight"])
self.orig_module.e_score_correction_bias = nn.Parameter(w["e_score_correction_bias"])
else:
raise ValueError("Invalid weight type")
self.orig_module.weight = self.orig_module.weight.to(device)
self.orig_module.e_score_correction_bias = self.orig_module.e_score_correction_bias.to(device)
def unload(self):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
......@@ -54,15 +54,15 @@ class KLinearBase(ABC):
self.has_bias = False
self.dtype = torch.get_default_dtype()
if orig_module is not None:
self.in_features = orig_module.in_features
self.out_features = orig_module.out_features
else:
shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
if len(shape) == 1:
print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
# if orig_module is not None:
# self.in_features = orig_module.in_features
# self.out_features = orig_module.out_features
# else:
shape = self.gguf_loader.tensor_info[key + ".weight"]["shape"]
if len(shape) == 1:
print("Warning: orig_module is not set, but has in_features or out_features equals to 1, can't get in_features and out_features from GGUF")
self.in_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][0]
self.out_features = self.gguf_loader.tensor_info[key + ".weight"]["shape"][1]
@abstractmethod
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -138,10 +138,10 @@ class KLinearTorch(KLinearBase):
if w is None: w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
self.w = w.to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.w = w.to(dtype=self.dtype).T
self.has_bias = False
elif isinstance(w, tuple):
self.w = w[0].to(dtype=self.dtype).view(self.out_features, self.in_features).T
self.w = w[0].to(dtype=self.dtype).T
self.bias = w[1].to(dtype=self.dtype)
self.has_bias = True
else:
......@@ -222,7 +222,7 @@ class KLinearMarlin(KLinearBase):
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, x.shape[-1])
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
x = KTransformersOps.gptq_marlin_gemm(
x,
......
......@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if inputs_embeds is None:
org_device = input_ids.device
# TODO move to embed_tokens's device, not hard code to cpu
input_ids = input_ids.to("cpu")
inputs_embeds = self.embed_tokens(input_ids).to(org_device)
input_ids = input_ids.to(org_device)
if cache_position is None:
past_seen_tokens = (
......@@ -639,12 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
if inputs_embeds is None:
org_device = input_ids.device
input_ids = input_ids.to("cpu")
inputs_embeds = self.embed_tokens(input_ids)
input_ids = input_ids.to(org_device)
if per_layer_prefill_flag:
causal_mask = None
else:
......@@ -716,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
self.load_layer_to(decoder_layer, InferenceState.PREFILL)
torch.cuda.empty_cache()
t4 = time.time()
# with open("log.txt", "a") as f:
# f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
......@@ -737,6 +740,7 @@ class KDeepseekV2Model(BaseInjectedModule):
hidden_states = layer_outputs[0]
# @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
......@@ -744,6 +748,10 @@ class KDeepseekV2Model(BaseInjectedModule):
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# with open("log.txt", "a") as f:
# f.write(f"@@@After layers\n")
# f.write(f"hidden_states={hidden_states}\n")
# f.write(f"hidden_states.shape={hidden_states.shape}\n")
if per_layer_prefill_flag:
t6 = time.time()
......
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:0"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda:1"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda:1"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
transfer_map:
30: "cuda:1"
- match:
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
replace:
class: "default"
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "(^model\\.layers\\.([3456][0-9])\\.)|(model.norm)|(lm_head)"
replace:
class: "default"
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoE # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExperts # custom MoE Kernel with expert paralleism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
\ No newline at end of file
......@@ -24,8 +24,8 @@ class KTransformersInterface(TransformersInterface):
self.args = args
torch.set_default_dtype(torch.bfloat16)
torch.set_grad_enabled(False)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=args.trust_remote_code)
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "Qwen2MoeForCausalLM":
config._attn_implementation = "flash_attention_2"
......@@ -46,51 +46,61 @@ class KTransformersInterface(TransformersInterface):
)
optimize_and_load_gguf(self.model, optimize_rule_path, gguf_path, config)
device_map = self.model.gguf_loader.tensor_device_map
logger.info(f"{args.model_name} loaded from {args.model_dir} to {device_map}")
self.device_map = self.model.gguf_loader.tensor_device_map
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {self.device_map}")
self.cache = StaticCache(
config=self.model.config,
max_batch_size=args.batch_size,
max_cache_len=args.cache_lens,
device=device_map,
device=self.device_map,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {device_map}, batch size:{args.batch_size}")
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
# logger.info(f"StaticCache (length={args.cache_lens}), batch size:{args.batch_size}")
try:
self.model.generation_config = GenerationConfig.from_pretrained(args.model_dir)
except:
gen_config = GenerationConfig(
max_length=128,
temperature=0.7,
top_p=0.9,
do_sample=True
)
self.model.generation_config = gen_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.streamer = TextStreamer(self.tokenizer)
def decode_one_tokens(self):
if not hasattr(self, "cuda_graph_runner"):
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
if self.args.use_cuda_graph:
if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
self.model,
self.current_ids,
self.active_cache_position.unsqueeze(0),
self.active_cache_position,
self.cache,
main_device=torch_device,
return_dict=False,
use_cache=True,
)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if hasattr(self, "cuda_graph_runner"):
logits = self.cuda_graph_runner(
self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position
)
self.cache.change_seq_length(1)
torch.cuda.synchronize()
logits = logits[0, -1, :]
return self.logits_to_token(logits)
if self.use_static_cache:
mask = torch.ones((1, self.seq_length)).to(torch_device)
logits = self.model(
self.current_ids,
self.current_ids.to(torch_device),
cache_position=self.active_cache_position,
past_key_values=self.cache,
attention_mask=mask,
......@@ -102,3 +112,63 @@ class KTransformersInterface(TransformersInterface):
logits = logits[0, -1, :]
return self.logits_to_token(logits)
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1]
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
if is_new:
self.cache.reset()
self.ever_generated_ids.clear()
former_seq_length = 0
self.seq_length = input_ids_length
self.generated_ids = torch.zeros(
self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
else:
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
mask = torch.ones((1, self.seq_length)).to(device)
if not (type(self) is TransformersInterface):
input_ids = input_ids.to("cpu")
inputs_embeds = self.model.model.embed_tokens(input_ids).to(device)
if self.use_static_cache:
logits = self.model(
inputs_embeds=inputs_embeds,
cache_position=cache_position,
past_key_values=self.cache,
return_dict=False,
use_cache=True,
attention_mask=mask,
)[0]
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)
@property
def active_cache_position(self):
device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
return torch.tensor([self.seq_length - 1], device=device)
\ No newline at end of file
......@@ -134,7 +134,7 @@ class TransformersInterface(BackendInterfaceBase):
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
self.model = AutoModelForCausalLM.from_pretrained(args.model_dir, device_map=args.device, use_safetensors=True)
logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
# logger.info(f"{args.model_name} loaded from {args.model_dir} to {args.device}")
self.cache = StaticCache(
config=self.model.config,
......@@ -143,7 +143,7 @@ class TransformersInterface(BackendInterfaceBase):
device=args.device,
dtype=self.model.dtype,
)
logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
# logger.info(f"StaticCache (length={args.cache_lens}) created at {args.device}, batch size:{args.batch_size}")
self.streamer = TextStreamer(self.tokenizer)
......@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
return self.streamer.put(new_tokens)
def logits_to_token(self, logits: torch.Tensor):
logits = logits / self.args.temperature
logits = logits / self.args.temperature if self.args.temperature!=0 else logits
for token_idx in self.ever_generated_ids:
if logits[token_idx] < 0:
......@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
#input_ids = torch.tensor([[6366]], device=input_ids.device)
else:
raise ValueError("local_messages should be List or str")
......@@ -327,14 +329,14 @@ class TransformersInterface(BackendInterfaceBase):
self.profiler.create_and_start_timer("prefill")
for t in self.prefill(input_ids, self.check_is_new(thread_id)):
if t is not None:
print(t, end="")
print(t, end="",flush=True)
yield t
self.profiler.pause_timer("prefill")
self.profiler.create_and_start_timer("decode")
for t in self.generate():
if t is not None:
print(t, end="")
print(t, end="",flush=True)
yield t
print("")
self.profiler.pause_timer("decode")
......
......@@ -93,6 +93,8 @@ class Config(metaclass=Singleton):
self.model_name: str = self.model.get("name", "")
self.model_device: str = self.model.get("device", "cuda:0")
self.gguf_path: Optional[str] = self.model.get("gguf_path", None)
self.use_cuda_graph = self.model.get("use_cuda_graph", True)
self.trust_remote_code = self.model.get("trust_remote_code", True)
# self.model_cache_lens = self.model.get("cache_lens")
self.optimize_config_path: Optional[str] = self.model.get(
"optimize_config_path", None
......@@ -102,7 +104,7 @@ class Config(metaclass=Singleton):
self.total_context = self.model.get("total_context", 2**18)
self.max_batch_size = self.model.get("max_batch_size", 20 if self.paged else 1)
self.max_chunk_size = self.model.get("max_chunk_size", 2048)
self.max_new_tokens = self.model.get("max_new_tokens", 500)
self.max_new_tokens = self.model.get("max_new_tokens", 2000)
self.json_mode = self.model.get("json_mode", False)
self.healing = self.model.get("healing", False)
self.ban_strings: Optional[list] = self.model.get("ban_strings", None)
......
This diff is collapsed.
fire
transformers
transformers==4.43.2
numpy
torch>=2.3.0
packaging
......
......@@ -278,13 +278,15 @@ class CMakeBuild(BuildExtension):
if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
if hasattr(self, "parallel") and self.parallel:
build_args += [f"-j{self.parallel}"]
print("CMake args:", cmake_args)
build_temp = Path(ext.sourcedir) / "build"
if not build_temp.exists():
build_temp.mkdir(parents=True)
subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True
result = subprocess.run(
["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True , capture_output=True
)
print("Standard output:", result.stdout)
print("Standard error:", result.stderr)
subprocess.run(
["cmake", "--build", ".", *build_args], cwd=build_temp, check=True
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment