"vscode:/vscode.git/clone" did not exist on "483d8fae63f0102a31e9842a593f462399116fbd"
Commit 25cee581 authored by Atream's avatar Atream
Browse files

add balance-serve, support concurrency

parent 8d0292aa
......@@ -359,3 +359,56 @@ class DynamicNTKScalingRotaryEmbedding(
self.orig_module.rope_type,
self.orig_module.config,
)
class RotaryEmbeddingV4(BaseInjectedModule):
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
generate_device: str = "cuda",
prefill_device: str = "cuda",
**kwargs,
):
BaseInjectedModule.__init__(
self, key, gguf_loader, config, orig_module, generate_device, **kwargs
)
self.generate_device = generate_device
self.prefill_device = prefill_device
@torch.no_grad()
def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def load(self):
self._init(
dim=self.config.qk_rope_head_dim,
max_position_embeddings=self.config.max_position_embeddings,
base=self.config.rope_theta,
device=self.device,
)
def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
self.scaling_factor = scaling_factor
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
# self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
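# Note (not part of the original commit): _init recomputes inv_freq on the target
# device and forward() then applies plain, non-scaled RoPE. A minimal standalone
# sketch of the same math, assuming dim=64 and base=10000:
#
#   import torch
#   dim, base = 64, 10000.0
#   inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim/2]
#   positions = torch.arange(8).float()                                 # [seq]
#   freqs = torch.outer(positions, inv_freq)                            # [seq, dim/2]
#   emb = torch.cat((freqs, freqs), dim=-1)                             # [seq, dim]
#   cos, sin = emb.cos(), emb.sin()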
\ No newline at end of file
......@@ -32,7 +32,8 @@ import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
if flashinfer_enabled:
from ktransformers.operators.flashinfer_wrapper import MLAWrapperSingleton
from flashinfer.mla import BatchMLAPagedAttentionWrapper
from ktransformers.models.custom_cache import KDeepSeekV3Cache
logger = logging.getLogger("attention")
# Copied from transformers.models.llama.modeling_llama.rotate_half
......@@ -759,3 +760,92 @@ class KLlamaAttention(BaseInjectedModule):
attn_weights = None
return attn_output, attn_weights, past_key_value
class flashinfer_attn(BaseInjectedModule, DeepseekV2Attention):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
chunck_size: int = 1000,
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.layer_idx)
self.chunck_size = chunck_size # TODO, generate chunck_size automatically.
def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
self.q_absorb.weight.data = q_absorb
self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
self.out_absorb.weight.data = out_absorb
#del self.orig_module.kv_b_proj
q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
return q_absorb, out_absorb
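# Note (added comment, not in the original commit): get_absorbed() splits the
# kv_b_proj weight into two per-head matrices so attention can run directly in the
# compressed kv_lora_rank space: q_absorb maps the latent cache onto the nope query
# dimension, and out_absorb maps the attention output back to v_head_dim. Both are
# cached as nn.Linear modules after the first call.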
def forward(self,
hidden_states: torch.Tensor,
kv_cache: KDeepSeekV3Cache,
position_ids: torch.Tensor,
wrapper: BatchMLAPagedAttentionWrapper,
num_tokens_tensors: torch.Tensor,
page_idx: torch.Tensor,
page_offset: torch.Tensor,
):
q_len, _ = hidden_states.size()
if self.q_lora_rank is None:
q = self.q_proj(hidden_states, num_tokens_tensors)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states, num_tokens_tensors), num_tokens_tensors), num_tokens_tensors)
q = q.view(q_len, self.num_heads, self.q_head_dim)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
compressed_kv = self.kv_a_proj_with_mqa(hidden_states, num_tokens_tensors)
compressed_kv, k_pe = torch.split(
compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
compressed_kv = compressed_kv.contiguous()
compressed_kv = self.kv_a_layernorm(compressed_kv, num_tokens_tensors)
k_pe = k_pe.view(q_len, 1, self.qk_rope_head_dim)
compressed_kv = compressed_kv.view(q_len, 1, self.kv_lora_rank)
cos, sin = self.rotary_emb(q_pe, position_ids.unsqueeze(0))
q_pe, k_pe = apply_rotary_pos_emb(q_pe.unsqueeze(0), k_pe.unsqueeze(0), cos, sin, unsqueeze_dim=2)
q_pe = q_pe.squeeze(0)
if kv_cache is not None:
# page_idx, page_offset = kv_cache.get_page_table(position_ids, q_indptr, kv_indptr, kv_indices)
cache_kwargs = {"sin": sin, "cos": cos, "page_idx": page_idx, "page_offset": page_offset} # Specific to RoPE models
compressed_kv_with_k_pe = kv_cache.update(compressed_kv.unsqueeze(0), k_pe, self.layer_idx, page_idx, page_offset, cache_kwargs)
compressed_kv = compressed_kv_with_k_pe[:, :, :, :self.kv_lora_rank].view(-1, kv_cache.page_size, self.kv_lora_rank)
k_pe = compressed_kv_with_k_pe[:, :, :, self.kv_lora_rank:].view(-1, kv_cache.page_size, self.qk_rope_head_dim)
q_absorb, out_absorb = self.get_absorbed()
q_nope = q_nope.transpose(0, 1) # q_len is 1, no GPU overhead, same below
q_nope = torch.matmul(q_nope, q_absorb) # batched MM
q_nope = q_nope.transpose(0, 1)
# q_nope.squeeze_(1)
# q_pe.squeeze_(1)
attn_output = wrapper.run(q_nope, q_pe, compressed_kv, k_pe).view(q_len, self.num_heads, self.kv_lora_rank)
attn_output = attn_output.transpose(0, 1)
attn_output = torch.matmul(attn_output, out_absorb.mT) # [self.num_heads, q_len, self.v_head_dim]
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(q_len, self.num_heads * self.v_head_dim)
attn_output = self.o_proj(attn_output, num_tokens_tensors)
return attn_output
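# Note (added comment, not in the original commit): this forward keeps the KV cache
# in compressed MLA form (kv_lora_rank + qk_rope_head_dim = 512 + 64 per token for
# DeepSeek-V3), applies RoPE only to the small rope slice, writes into the paged
# cache via page_idx/page_offset, and then lets the pre-planned
# BatchMLAPagedAttentionWrapper attend over the compressed pages before projecting
# back out with the absorbed matrices and o_proj.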
\ No newline at end of file
......@@ -37,6 +37,10 @@ import time
from ktransformers.operators.cpuinfer import CPUInfer
def deduplicate_and_sort(lst):
return sorted(set(lst))
#cuda_graphs = [Config().chunk_size]
cuda_graphs = deduplicate_and_sort([1, 2, 3, Config().max_batch_size, 64, Config().chunk_size])
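# Note (added comment, not in the original commit): cuda_graphs lists the batch
# sizes for which CUDA graphs are captured; KExpertsCPU.load() below allocates one
# pinned CPU staging buffer and one GPU output buffer per entry, indexed by
# cuda_graph_idx in submit_for_one_decode/forward.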
# class Base(BaseInjectedModule, ABC):
class KExpertsBase(ABC):
def __init__(self, key: str, gguf_loader: GGUFLoader, config: PretrainedConfig, orig_module: nn.Module, device: str = "cuda", **kwargs):
......@@ -112,6 +116,7 @@ class KExpertsBase(ABC):
tensors[k] = self.gguf_loader.load_gguf_tensor(key + k, device=device)
return tensors
class KExpertsCPU(KExpertsBase):
input_tensor_cpu:Tensor = None
expert_ids_cpu:Tensor = None
......@@ -119,8 +124,8 @@ class KExpertsCPU(KExpertsBase):
output_cpu:Tensor = None
output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
#stream_map:dict = {} # Manage cuda stream on different gpu
#gguf_loader:GGUFLoader = None
CPU_INFER = None
# @TODO add yaml
CPU_INFER = CPUInfer(Config().cpu_infer)
def __init__(
self,
key: str,
......@@ -133,11 +138,6 @@ class KExpertsCPU(KExpertsBase):
**kwargs
):
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
if KExpertsCPU.CPU_INFER is None:
KExpertsCPU.CPU_INFER = CPUInfer(Config().cpu_infer)
#if KExpertsCPU.gguf_loader is None:
# KExpertsCPU.gguf_loader = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
self.gguf_loader = gguf_loader
assert device.lower() == "cpu", "KExpertsCPU can only be loaded on CPU"
self.n_routed_experts = n_routed_experts
self.out_device = out_device
......@@ -161,7 +161,7 @@ class KExpertsCPU(KExpertsBase):
down_ptr = ctypes.addressof(
ctypes.cast(self.down.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
)
#print(self.gate_type, self.up_type, self.down_type)
# print(self.gate_qtype, self.up_qtype, self.down_qtype)
n_routed_experts = self.n_routed_experts
# n_routed_experts = len(self.orig_module)
moe_config = MOEConfig(
......@@ -188,43 +188,83 @@ class KExpertsCPU(KExpertsBase):
self.cpu_infer.submit(self.moe.warm_up())
self.cpu_infer.sync()
if self.out_device not in KExpertsCPU.output_gpu_map:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((self.config.hidden_size), device=self.out_device)
if isinstance(cuda_graphs, list):
KExpertsCPU.output_gpu_map[self.out_device] = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device=self.out_device) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.output_gpu_map[self.out_device] = torch.zeros((cuda_graphs, self.config.hidden_size), device=self.out_device)
if KExpertsCPU.input_tensor_cpu is None:
KExpertsCPU.input_tensor_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
if isinstance(cuda_graphs, list):
KExpertsCPU.input_tensor_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.expert_ids_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.weights_cpu = [torch.zeros((cuda_graphs[i], num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True) for i in range(len(cuda_graphs))]
KExpertsCPU.output_cpu = [torch.zeros((cuda_graphs[i], self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16) for i in range(len(cuda_graphs))]
KExpertsCPU.bsz_tensor_cpu = [torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True) for i in range(len(cuda_graphs))]
else:
KExpertsCPU.input_tensor_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True)
KExpertsCPU.expert_ids_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.long, pin_memory=True)
KExpertsCPU.weights_cpu = torch.zeros((cuda_graphs, num_experts_per_tok), device="cpu", dtype=torch.float32, pin_memory=True)
KExpertsCPU.output_cpu = torch.zeros((cuda_graphs, self.config.hidden_size), device="cpu", pin_memory=True, dtype=torch.bfloat16)
KExpertsCPU.bsz_tensor_cpu = torch.zeros((1), device="cpu", dtype=torch.int32, pin_memory=True)
def submit_for_one_decode(self, input_tensor, expert_ids, weights):
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(0), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
def sync_for_one_decode(self):
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
def forward(self, input_tensor, expert_ids, weights):
# generate, capture and run cuda graph
# print(expert_ids)
if input_tensor.size(0)==1 and torch.cuda.is_current_stream_capturing():
# TODO: this branch is unreachable, but the shapes of input_tensor ([1, hidden_size]) and input_tensor_cpu ([hidden_size]) are not compatible
#print("capturing experts")
def submit_for_one_decode(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
if bsz_tensor is None:
bsz_tensor = torch.ones(1, device=input_tensor.device, dtype=torch.int32)
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
else:
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(1, expert_ids.size(1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream, self.moe.forward(1, expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
def sync_for_one_decode(self, cuda_graph_idx=0):
if cuda_graph_idx != -1:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream(self.out_device).cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
def forward(self, input_tensor, expert_ids, weights, bsz_tensor=None, cuda_graph_idx=0):
# generate, capture and run cuda graph
# print(expert_ids)
if bsz_tensor is None:
bsz_tensor = torch.tensor([input_tensor.size(0)], device=input_tensor.device, dtype=torch.int32)
if torch.cuda.is_current_stream_capturing():
if cuda_graph_idx != -1:
KExpertsCPU.input_tensor_cpu[cuda_graph_idx].copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu[cuda_graph_idx].copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu[cuda_graph_idx].copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.weights_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.input_tensor_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.output_cpu[cuda_graph_idx].data_ptr(), KExpertsCPU.bsz_tensor_cpu[cuda_graph_idx].data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx].copy_(KExpertsCPU.output_cpu[cuda_graph_idx], non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device][cuda_graph_idx]
else:
KExpertsCPU.input_tensor_cpu.copy_(input_tensor, non_blocking=True)
KExpertsCPU.expert_ids_cpu.copy_(expert_ids, non_blocking=True)
KExpertsCPU.weights_cpu.copy_(weights, non_blocking=True)
KExpertsCPU.bsz_tensor_cpu.copy_(bsz_tensor, non_blocking=True)
self.cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, self.moe.forward(expert_ids.size(0), expert_ids.size(-1), KExpertsCPU.expert_ids_cpu.data_ptr(), KExpertsCPU.weights_cpu.data_ptr(), KExpertsCPU.input_tensor_cpu.data_ptr(), KExpertsCPU.output_cpu.data_ptr(), KExpertsCPU.bsz_tensor_cpu.data_ptr()))
self.cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
KExpertsCPU.output_gpu_map[self.out_device].copy_(KExpertsCPU.output_cpu, non_blocking=True)
return KExpertsCPU.output_gpu_map[self.out_device]
else:
input_tensor = input_tensor.contiguous().cpu()
expert_ids = expert_ids.contiguous().cpu()
weights = weights.contiguous().to(torch.float32).cpu()
bsz_tensor = bsz_tensor.contiguous().cpu()
output = torch.empty_like(input_tensor).contiguous()
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr()))
self.cpu_infer.submit(self.moe.forward(expert_ids.size(0), expert_ids.size(1), expert_ids.data_ptr(), weights.data_ptr(), input_tensor.data_ptr(), output.data_ptr(), bsz_tensor.data_ptr()))
self.cpu_infer.sync()
return output.to(device=object.__getattribute__(self, "out_device"))
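# Note (added comment, not in the original commit): forward() above has three
# paths: while a CUDA graph is being captured it stages inputs into the per-graph
# pinned buffers (cuda_graph_idx != -1) or the shared pinned buffers, submits the
# MoE kernel on the current CUDA stream, and returns the matching pre-allocated GPU
# output buffer so the graph records stable addresses; outside capture it simply
# round-trips the tensors through CPU memory and runs synchronously.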
......@@ -859,6 +899,8 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
y += y_
return y
@torch.no_grad()
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
......@@ -1013,4 +1055,178 @@ class KMistralSparseMoEBlock(BaseInjectedModule, MixtralSparseMoeBlock):
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
return final_hidden_states
\ No newline at end of file
return final_hidden_states
class KDeepseekV3MoEV2(BaseInjectedModule, DeepseekV3MoE):
def forward(self, hidden_states, bsz_tensor, cuda_graph_idx=0):
identity = hidden_states
orig_shape = hidden_states.shape
sequence_length = orig_shape[1]
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
# only for generate phase
if hasattr(self.experts.generate_experts, "submit_for_one_decode") and torch.cuda.is_current_stream_capturing(): # TODO: this branch causes a JIT bug
self.experts.generate_experts.submit_for_one_decode(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx)
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
y = self.experts.generate_experts.sync_for_one_decode(cuda_graph_idx).unsqueeze(0)
y += y_
y.resize_(*orig_shape)
return y
if self.config.n_shared_experts is not None:
y_ = self.shared_experts(identity, bsz_tensor).squeeze(0)
if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight, bsz_tensor, cuda_graph_idx).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO: may have bugs here
y = (
self.moe_infer(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
else:
# TODO: may have bugs here
y = (
self.moe_infer_simple(hidden_states, topk_idx, topk_weight)
.view(*orig_shape)
.to(device=hidden_states.device)
)
if self.config.n_shared_experts is not None:
y += y_
return y
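# Note (added comment, not in the original commit): during CUDA-graph decode the
# forward above overlaps CPU and GPU work: it submits the routed-expert job to
# CPUInfer first, runs the shared experts on GPU while the CPU experts compute,
# and only then syncs and adds the two results.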
@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor, bsz_tensor, cuda_graph_idx=0) -> torch.Tensor:
outs = torch.empty_like(x)
outs = self.experts(x, topk_ids, topk_weight, bsz_tensor, cuda_graph_idx)
return outs
@torch.no_grad()
# TODO: may have bugs here
def moe_infer_simple(
self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor
) -> torch.Tensor:
"""
x: [num_tokens, hidden_size]
topk_ids, topk_weight: [num_tokens, num_selected_experts]
"""
outs = torch.zeros_like(x)
for token_idx in range(topk_ids.size(0)):
for expert_idx in range(topk_ids.size(1)):
expert = self.experts[topk_ids[token_idx, expert_idx]]
outs[token_idx] += (
expert.forward(x[token_idx]) * topk_weight[token_idx, expert_idx]
)
return outs
@torch.no_grad()
# TODO: may have bugs here
def moe_infer(self, x, topk_ids, topk_weight):
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = expert.forward(tokens_for_this_expert)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weight.dtype)
.mul_(topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return final_out
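# Sketch (not part of the original commit) of the argsort/gather pattern used by
# moe_infer above, assuming 2 tokens routed to 2 experts each:
#
#   topk_ids = torch.tensor([[0, 1], [1, 0]])
#   idxs = topk_ids.view(-1).argsort()            # token copies grouped by expert
#   sorted_tokens = x[idxs // topk_ids.shape[1]]  # gather each token once per expert
#   # run each expert on its contiguous slice and concatenate into `outs`;
#   # new_x[idxs] = outs then restores the original (token, expert) order before
#   # the weighted sum over the expert dimension.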
class KTransformersExpertsV2(BaseInjectedModule, KExpertsBase):
def __init__(self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
# device: str = "cuda",
prefill_device:str = "cuda",
prefill_op: str | None = "KExpertsTorch",
generate_device: str = "cpu",
generate_op: str | None = "KExpertsCPU",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
KExpertsBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
if generate_op is not None:
self.generate_experts = EXPERTS_MAP[generate_op](key, gguf_loader, config, len(orig_module), device=generate_device, **kwargs)
else:
self.generate_experts = None
if prefill_op is not None:
self.prefill_experts = EXPERTS_MAP[prefill_op](key, gguf_loader, config, len(orig_module), device=prefill_device, **kwargs)
else:
self.prefill_experts = None
self.gpu_mlp_type = prefill_op
self.cpu_mlp_type = generate_op
self.mode = InferenceState.UNLOAD
def load(self, w: dict = None, mode: InferenceState = None, warmup: bool = True):
# TODO support w as input
if not mode: mode = InferenceState.GENERATE
if mode == InferenceState.GENERATE:
self.prefill_experts.unload()
self.generate_experts.load(w, warmup=warmup)
self.device = self.generate_experts.device
self.mode = mode
elif mode == InferenceState.PREFILL:
self.generate_experts.unload()
self.prefill_experts.load(w, warmup=warmup)
self.device = self.prefill_experts.device
self.mode = mode
elif mode == InferenceState.UNLOAD:
self.unload()
self.mode = mode
self.device = self.generate_experts.device
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
def unload(self):
if self.generate_experts is not None:
self.generate_experts.unload()
if self.prefill_experts is not None:
self.prefill_experts.unload()
self.device = self.generate_experts.device
def forward(self, input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx=0):
if self.mode == InferenceState.GENERATE:
assert self.generate_experts is not None, "generate_experts is None"
return self.generate_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
elif self.mode == InferenceState.PREFILL:
assert self.prefill_experts is not None, "prefill_experts is None"
return self.prefill_experts.forward(input_tensor, expert_ids, weights, bsz_tensor, cuda_graph_idx)
else:
raise ValueError("load or set_inference_mode before forward")
def set_inference_mode(self, mode: InferenceState):
if mode == InferenceState.GENERATE:
self.load(mode=InferenceState.GENERATE, warmup=False)
elif mode == InferenceState.PREFILL:
self.load(mode=InferenceState.PREFILL, warmup=False)
elif mode == InferenceState.UNLOAD:
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
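# Note (added comment, not in the original commit): a typical call sequence for
# KTransformersExpertsV2 is load() (or set_inference_mode(InferenceState.GENERATE /
# PREFILL)) followed by forward(input_tensor, expert_ids, weights, bsz_tensor,
# cuda_graph_idx); forward simply dispatches to generate_experts or prefill_experts
# depending on the current mode.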
......@@ -86,6 +86,7 @@ class MLAWrapper():
self.qo_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indptr_buf = torch.empty(max_batch_size+1, dtype=torch.int32, device=device)
self.kv_indices_buf = torch.empty(max_pages, dtype=torch.int32, device=device)
self.batch_size_tensor_buf = torch.tensor([self.max_batch_size], dtype=torch.int32, device=device)
self.kv_len_arr_buf = torch.empty(max_batch_size, dtype=torch.int32, device=device)
else:
self.qo_indptr_buf = None
......@@ -94,19 +95,22 @@ class MLAWrapper():
self.kv_len_arr_buf = None
self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.float_workspace_buffer,
use_cuda_graph=False,
use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,
kv_indptr=self.kv_indptr_buf,
kv_indices=self.kv_indices_buf,
kv_len_arr=self.kv_len_arr_buf,
bsz_tensor=self.batch_size_tensor_buf
)
self.need_plan = True
def plan(self,
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
num_heads,
head_dim_ckv,
head_dim_kpe,
......@@ -138,6 +142,7 @@ class MLAWrapper():
sm_scale,
q_data_type,
kv_data_type,
bsz_tensor
)
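# Note (added comment, not in the original commit): bsz_tensor appears to carry the
# live batch size on-device so that a wrapper planned and captured once (with
# use_cuda_graph=True and fixed-size buffers) can be replayed for a smaller batch
# without re-recording the graph; the __main__ test below captures a graph at one
# shape and then replays it after overwriting the buffers and bsz_tensor.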
def run(self, q_nope, q_pe, ckv, k_pe, return_lse = False):
......@@ -240,16 +245,17 @@ if __name__ == "__main__":
#checksame()
#exit(0)
max_batch_size = 1
max_pages = 64
max_batch_size = 2
max_batch_tokens = 256
max_pages = 128
page_size = 64
num_heads = 128
# warm-up
kv_len = 4023
q_len = 1
q_nope_buf = torch.randn((q_len, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((q_len, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_nope_buf = torch.randn((max_batch_tokens, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_pe_buf = torch.randn((max_batch_tokens, num_heads, 64), dtype=torch.bfloat16, device="cuda")
kv_buf = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
ckv, k_pe = torch.split(kv_buf, [512, 64], dim=-1)
......@@ -260,13 +266,19 @@ if __name__ == "__main__":
max_pages,
)
used_pages = (kv_len + page_size - 1)// page_size
kv_len_arr = torch.tensor([kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, used_pages], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
kv_indices[:used_pages] = torch.arange(0, used_pages, dtype=torch.int32, device="cuda")
bsz_tensor = torch.tensor([1], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
None,
None,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
128,
512,
64,
......@@ -276,14 +288,98 @@ if __name__ == "__main__":
torch.bfloat16,
)
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
attn_output = wrapper.run(q_nope_buf[:q_len], q_pe_buf[:q_len], ckv, k_pe)
print(attn_output.shape)
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
attn_output = wrapper.run(q_nope_buf, q_pe_buf, ckv, k_pe)
graph.replay()
q = torch.cat([q_nope_buf, q_pe_buf], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
1,
q[:q_len],
k[:kv_len],
v[:kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_output[:q_len], attn_ref, rtol=5e-3, atol=5e-3)
# warm-up finished
kv_len = 512
q_len = 128
pages = max_pages
used_pages = (kv_len + page_size - 1)// page_size
q_nope = torch.randn((q_len*2, num_heads, 512), dtype=torch.bfloat16, device="cuda")
q_nope[q_len:] = q_nope[:q_len]
q_pe = torch.randn((q_len*2, num_heads, 64), dtype=torch.bfloat16, device="cuda")
q_pe[q_len:] = q_pe[:q_len]
kv_cache = torch.randn((max_pages, page_size, 576), dtype=torch.bfloat16, device="cuda")
kv_cache[used_pages:2*used_pages] = kv_cache[:used_pages]
ckv, k_pe = torch.split(kv_cache, [512, 64], dim=-1)
kv_len_arr = torch.tensor([kv_len, kv_len], dtype=torch.int32, device="cuda")
qo_indptr = torch.tensor([0, q_len, q_len*2], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, used_pages, used_pages*2], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(max_pages, dtype=torch.int32, device="cuda")
kv_indices[:2*used_pages] = torch.arange(0, 2*used_pages, dtype=torch.int32, device="cuda")
bsz_tensor = torch.tensor([2], dtype=torch.int32, device="cuda")
wrapper.plan(
qo_indptr,
kv_indptr,
kv_indices,
kv_len_arr,
bsz_tensor,
128,
512,
64,
page_size,
192 ** (-0.5),
torch.bfloat16,
torch.bfloat16,
)
q_nope_buf.copy_(q_nope)
q_pe_buf.copy_(q_pe)
kv_buf[:pages].copy_(kv_cache)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
# ref_torch
q = torch.cat([q_nope, q_pe], dim=-1)
k = (
torch.cat([ckv, k_pe], dim=-1)
.view(-1, 1, 512 + 64)
.repeat_interleave(num_heads, dim=1)
)
v = ckv.view(-1, 1, 512).repeat_interleave(num_heads, dim=1)
attn_ref, lse_ref = attention_ref_torch(
max_batch_size,
q,
k[:2*kv_len],
v[:2*kv_len],
True,
192 ** (-0.5)
)
torch.testing.assert_close(attn_ref[:q_len], attn_ref[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
torch.testing.assert_close(attn_output[:q_len], attn_ref[:q_len], rtol=5e-3, atol=5e-3)
torch.testing.assert_close(attn_output[q_len:q_len*2], attn_ref[q_len:q_len*2], rtol=5e-3, atol=5e-3)
#torch.testing.assert_close(attn_output[:q_len], attn_output[q_len:q_len*2], rtol=1e-9, atol=1e-9)
#torch.testing.assert_close(attn_output, attn_ref, rtol=5e-3, atol=5e-3)
exit(0)
for forward_id in range(0, 1):
print("forward_id", forward_id)
for layer_id in range(1):
......@@ -376,5 +472,4 @@ if __name__ == "__main__":
#file_name = f"./flashinfer_output/layer_{layer_id}_forward_{forward_id}_attn_output.pt"
#ktrans_output = torch.load(file_name)
#torch.testing.assert_close(attn_output, ktrans_output.squeeze(1), rtol=1e-3, atol=1e-3)
print("test past")
print("test past")
\ No newline at end of file
......@@ -249,4 +249,4 @@ class KMoEGateDeepSeekV3(BaseInjectedModule, KMoEGateBase):
if self.weight is not None:
self.weight = None
if self.e_score_correction_bias is not None:
self.e_score_correction_bias = None
self.e_score_correction_bias = None
\ No newline at end of file
'''
Date: 2024-11-13 15:05:52
LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-25 08:59:19
'''
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
from transformers import PretrainedConfig
import torch
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3RMSNorm
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from flashinfer.norm import (
fused_add_rmsnorm,
rmsnorm,
)
logger = logging.getLogger(__name__)
class RMSNorm(DeepseekV3RMSNorm, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.hidden_size,
orig_module.variance_epsilon)
def forward(
self,
x: torch.Tensor,
batch_size_tensor: torch.Tensor = None,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
#return self.forward_native(x, residual)
if batch_size_tensor is None:
return self.forward_native(x)
if residual is not None:
fused_add_rmsnorm(x, residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
#residual = x + residual
#out = rmsnorm(residual, self.weight.data, batch_size_tensor, self.variance_epsilon)
return x, residual
# print(x.shape, self.weight.data.shape, self.variance_epsilon, x.dtype, self.weight.data.dtype, x.device, self.weight.device, x.is_contiguous(), self.weight.data.is_contiguous())
out = rmsnorm(x, self.weight.data, batch_size_tensor, self.variance_epsilon)
return out
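# Note (added comment, not in the original commit): when a residual is given,
# flashinfer's fused_add_rmsnorm is expected to update x and residual in place
# (residual += x, then x <- RMSNorm(residual)), which is why the method returns the
# same tensors it received; without batch_size_tensor the module falls back to the
# pure-PyTorch forward_native path.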
def forward_native(
self, hidden_states
):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
\ No newline at end of file
......@@ -15,14 +15,16 @@ import ctypes
import torch
from torch import Tensor, nn
import KTransformersOps
import vLLMMarlin
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marlin_utils import (
MarlinWorkspace,
marlin_quantize,
marlin_quantize,
GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MIN_THREAD_K,
GPTQ_MARLIN_MAX_PARALLEL,
vllm_marlin_quantize
)
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
......@@ -84,8 +86,10 @@ class KLinearBase(ABC):
if self.gguf_loader.safetensor_loader is not None:
# using safetensor_loader
tensor = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight')
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
if key+'.weight_scale_inv' in self.gguf_loader.safetensor_loader.tensor_file_map:
weight_scale_inv = self.gguf_loader.safetensor_loader.load_tensor(key+'.weight_scale_inv')
return nn.Parameter(tensor), nn.Parameter(weight_scale_inv)
return nn.Parameter(tensor)
elif key + ".weight" in self.gguf_loader.tensor_file_map:
if key + ".bias" in self.gguf_loader.tensor_file_map:
......@@ -134,7 +138,7 @@ class KLinearTorch(KLinearBase):
self.weight = None
self.has_bias = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
dtype = x.dtype
out_device = x.device
# TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
......@@ -178,7 +182,6 @@ class KLinearTorch(KLinearBase):
if self.has_bias:
self.bias = None
class KLinearQ8(KLinearBase):
def __init__(
self,
......@@ -370,7 +373,7 @@ class KLinearFP8(KLinearBase):
self.dtype = torch.get_default_dtype()
self.block_size = block_size
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor) -> torch.Tensor:
x = x.to(self.device)
orig_dtype = x.dtype
x_quantized, scale_x = act_quant(x, self.block_size)
......@@ -397,8 +400,152 @@ class KLinearFP8(KLinearBase):
self.weight = None
if self.has_bias:
self.bias = None
# TODO: merge two marlin class
class VLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
g_idx: torch.Tensor
sort_indices: torch.Tensor
has_bias: bool
def __init__(
self,
key: str,
gguf_loader: GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module = None,
device: str = "cuda",
num_bits: int = 4, # 4-bit/8-bit is supported
group_size: int = 64, # -1, 32, 64, 128
act_order: bool = False,
is_k_full=True,
**kwargs,
):
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
self.num_bits = num_bits
self.group_size = group_size
self.act_order = act_order
self.is_k_full = is_k_full
self.padding = False
self.orin_in_features = self.in_features
self.orin_out_features = self.out_features
if self.in_features%GPTQ_MARLIN_MIN_THREAD_K!=0 or self.out_features%GPTQ_MARLIN_MIN_THREAD_K!=0:
#print(f"warning!, in_features={in_features} or out_features={out_features} is undivisible by GPTQ_MARLIN_MIN_THREAD_K={GPTQ_MARLIN_MIN_THREAD_K} and GPTQ_MARLIN_MIN_THREAD_N={GPTQ_MARLIN_MIN_THREAD_N}, padding")
self.padding = True
self.in_features = (self.in_features+GPTQ_MARLIN_MIN_THREAD_K-1)//GPTQ_MARLIN_MIN_THREAD_K*GPTQ_MARLIN_MIN_THREAD_K
self.out_features = (self.out_features+GPTQ_MARLIN_MIN_THREAD_N-1)//GPTQ_MARLIN_MIN_THREAD_N*GPTQ_MARLIN_MIN_THREAD_N
#print(f"After padding: in_features={in_features}, out_features={out_features}")
self.k = self.in_features
self.n = self.out_features
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None):
if self.loaded: return
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin quantized linear only supports GPU device"
#if self.in_features * self.out_features:
if w is None:
w = self.load_weight(device=device)
if isinstance(w, nn.Parameter):
# pad weight
weight = w.view(self.orin_out_features, self.orin_in_features).T
self.has_bias = False
elif isinstance(w, tuple):
w = list(w)
weight = w[0].view(self.orin_out_features, self.orin_in_features).T
self.bias = w[1].view(self.orin_out_features)
self.has_bias = True
else:
raise ValueError("Invalid weight type")
weight = weight.to(device)
if self.has_bias:
self.bias = self.bias.to(device)
if self.padding:
padded_weight = torch.zeros(self.in_features, self.out_features, device=self.device)
padded_weight[:self.orin_in_features, :self.orin_out_features] = weight
weight = padded_weight
# Pack Marlin linear
marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
weight, self.num_bits, self.group_size, self.act_order
)
self.workspace = MarlinWorkspace(
self.out_features, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL,self.device
)
self.weight = marlin_q_w
self.marlin_q_w = marlin_q_w
self.marlin_s = marlin_s
self.g_idx = g_idx
self.sort_indices = sort_indices
self.k = weight.shape[0]
self.n = weight.shape[1]
# self.shape_buffer = torch.tensor([60], dtype=torch.int32, device=self.device)
self.loaded = True
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor = None) -> torch.Tensor:
if bsz_tensor is None:
bsz_tensor = torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device)
# Only support input x as BF16 and FP16
x = x.to(self.device)
orig_shape = list(x.shape)
orig_dtype = x.dtype
x = x.reshape(-1, orig_shape[-1])
marlin_s = self.marlin_s.to(x.dtype)
sms = -1
x = vLLMMarlin.gptq_marlin_gemm(
x,
self.marlin_q_w,
marlin_s,
self.g_idx,
self.sort_indices,
self.workspace.scratch,
self.num_bits,
bsz_tensor,
# torch.tensor([x.shape[0]], dtype=torch.int32, device=self.device),
x.shape[0],
self.n,
x.shape[-1],
sms,
self.is_k_full,
)
# x = KTransformersOps.gptq_marlin_gemm(
# x,
# self.marlin_q_w,
# marlin_s,
# self.g_idx,
# self.sort_indices,
# self.workspace.scratch,
# self.num_bits,
# x.shape[0],
# self.n,
# x.shape[-1],
# self.is_k_full,
# )
if self.has_bias:
x = x + self.bias
orig_shape[-1] = self.n
return x.reshape(orig_shape).to(orig_dtype)
def unload(self):
if self.has_bias:
self.bias = None
self.marlin_q_w = None
self.marlin_s = None
self.g_idx = None
self.sort_indices = None
self.workspace = None
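# Note (added comment, not in the original commit): compared with KLinearMarlin
# below, VLinearMarlin routes through vLLMMarlin.gptq_marlin_gemm and passes
# bsz_tensor so the kernel can read the live row count from device memory, which
# appears intended to keep the op replayable under CUDA graphs where x always has
# the captured maximum shape.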
class KLinearMarlin(KLinearBase):
marlin_q_w: torch.Tensor
marlin_s: torch.Tensor
......@@ -483,7 +630,7 @@ class KLinearMarlin(KLinearBase):
self.n = weight.shape[1]
self.loaded = True
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, bsz_tensor: torch.Tensor=None, **kwargs) -> torch.Tensor:
# Only support input x as BF16 and FP16
x = x.to(self.device)
orig_shape = list(x.shape)
......@@ -629,12 +776,13 @@ class KLinearCPUInfer(KLinearBase):
if self.w is not None:
self.w = None
if self.has_bias:
self.bias = None
self.bias = None
LINEAR_MAP = {
"KLinearMarlin": KLinearMarlin,
"KLinearTorch": KLinearTorch,
"KLinearCPUInfer": KLinearCPUInfer,
"VLinearMarlin": VLinearMarlin,
"KLinearFP8": KLinearFP8,
"KLinearQ8": KLinearQ8,
}
......@@ -668,13 +816,13 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.generate_linear = None
self.mode = InferenceState.UNLOAD
def forward(self, x):
def forward(self, x, bsz_tensor=None):
if self.mode == InferenceState.PREFILL:
assert self.prefill_linear is not None, "cpu linear is not initialized"
y = self.prefill_linear.forward(x)
y = self.prefill_linear.forward(x, bsz_tensor)
else:
assert self.generate_linear is not None, "gpu linear is not initialized"
y = self.generate_linear.forward(x)
y = self.generate_linear.forward(x, bsz_tensor)
return y
def load(self, w: dict | nn.Parameter | tuple | None = None, mode: InferenceState = InferenceState.GENERATE):
......@@ -717,3 +865,5 @@ class KTransformersLinear(BaseInjectedModule, KLinearBase):
self.unload()
else:
raise ValueError("mode must be either InferenceState.GENERATE, InferenceState.PREFILL or InferenceState.UNLOAD")
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
from transformers import PretrainedConfig
import torch.nn as nn
from ktransformers.models.modeling_deepseek_v3 import DeepseekV3MLP
class kDeepseekV3MLP(DeepseekV3MLP, BaseInjectedModule):
def __init__(self,
key: str,
gguf_loader : GGUFLoader,
config: PretrainedConfig,
orig_module: nn.Module,
prefill_device: str = "cuda",
generate_device: str = "cuda",
**kwargs):
BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, prefill_device, **kwargs)
self.orig_module.__init__(orig_module.config,
orig_module.hidden_size, orig_module.intermediate_size)
def forward(self, x, bsz_tensor):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, bsz_tensor)) * self.up_proj(x, bsz_tensor), bsz_tensor)
return down_proj
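# Note (added comment, not in the original commit): this is the standard gated
# (SwiGLU-style) MLP, down(act(gate(x)) * up(x)); the only change is that the
# gate/up/down projections (replaced by KTransformersLinear via the optimize rules)
# now also take bsz_tensor.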
\ No newline at end of file
......@@ -22,7 +22,7 @@
replace:
class: ktransformers.operators.linear.KTransformersLinear
kwargs:
generate_device: "cpu"
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearMarlin"
prefill_op: "KLinearTorch"
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "KLinearFP8"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
\ No newline at end of file
......@@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -18,7 +18,7 @@
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
......@@ -10,7 +10,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.KMoEGateDeepSeekV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -66,7 +66,7 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......@@ -74,7 +74,7 @@
name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
class: ktransformers.operators.gate.KMoEGateDeepSeekV3 # mlp module with custom forward function
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
......
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
\ No newline at end of file
......@@ -38,7 +38,7 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
class: ktransformers.operators.gate.KMoEGate
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
- match:
name: "^lm_head$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.(?!.*self_attn\\.kv_b_proj).*$" # regular expression
class: torch.nn.Linear # only match modules matching name and class simultaneously
replace:
class: ktransformers.operators.linear.KTransformersLinear # optimized Kernel on quantized data types
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
generate_op: "VLinearMarlin"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MoE
replace:
class: ktransformers.operators.experts.KDeepseekV3MoEV2 # mlp module with custom forward function
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\..*\\.mlp\\.experts$"
replace:
class: ktransformers.operators.experts.KTransformersExpertsV2 # custom MoE Kernel with expert parallelism
kwargs:
prefill_device: "cuda"
prefill_op: "KExpertsTorch"
generate_device: "cpu"
generate_op: "KExpertsCPU"
out_device: "cuda"
recursive: False # don't recursively inject submodules of this module
- match:
name: "^model\\.layers\\..*\\.self_attn$"
replace:
class: ktransformers.operators.attention.flashinfer_attn # optimized MLA implementation
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
absorb_for_prefill: False # change this to True to enable long context (prefill may be slower).
- match:
name: "^model$"
replace:
class: "ktransformers.operators.models.KDeepseekV2Model"
kwargs:
per_layer_prefill_intput_threshold: 0 # 0 disables layer-wise prefill
- match:
name: "^model.embed_tokens"
replace:
class: "default"
kwargs:
generate_device: "cpu"
prefill_device: "cpu"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RMSNorm
replace:
class: ktransformers.operators.layernorm.RMSNorm
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3MLP
replace:
class: ktransformers.operators.mlp.kDeepseekV3MLP
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- match:
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV4
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
\ No newline at end of file
......@@ -38,7 +38,7 @@
- match:
class: ktransformers.models.modeling_deepseek_v3.MoEGate
replace:
class: ktransformers.operators.gate.KMoEGate
class: ktransformers.operators.gate.KMoEGateDeepSeekV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
......
import argparse
from ktransformers.server.backend.args import ConfigArgs, default_args
from ktransformers.util.utils import get_free_ports
class ArgumentParser:
def __init__(self, cfg):
......@@ -16,20 +16,18 @@ class ArgumentParser:
parser.add_argument("--web", type=bool, default=self.cfg.mount_web)
parser.add_argument("--model_name", type=str, default=self.cfg.model_name)
parser.add_argument("--model_dir", type=str)
parser.add_argument("--model_path", type=str)
parser.add_argument("--model_path", type=str, default=self.cfg.model_path)
parser.add_argument(
"--device", type=str, default=self.cfg.model_device, help="Warning: Abandoning this parameter"
)
parser.add_argument("--gguf_path", type=str, default=self.cfg.gguf_path)
parser.add_argument("--optimize_config_path", default=self.cfg.optimize_config_path, type=str, required=False)
parser.add_argument("--optimize_config_path", default=None, type=str, required=False)
parser.add_argument("--cpu_infer", type=int, default=self.cfg.cpu_infer)
parser.add_argument("--type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_prefill_size", type=int, default=8192)
parser.add_argument("--backend_type", type=str, default=self.cfg.backend_type)
parser.add_argument("--chunk_size", type=int, default=self.cfg.chunk_size)
# model configs
# parser.add_argument("--model_cache_lens", type=int, default=self.cfg.cache_lens) # int?
parser.add_argument("--paged", type=bool, default=self.cfg.paged)
parser.add_argument("--total_context", type=int, default=self.cfg.total_context)
parser.add_argument("--max_batch_size", type=int, default=self.cfg.max_batch_size)
parser.add_argument("--max_new_tokens", type=int, default=self.cfg.max_new_tokens)
parser.add_argument("--json_mode", type=bool, default=self.cfg.json_mode)
......@@ -62,7 +60,6 @@ class ArgumentParser:
parser.add_argument("--repetition_penalty", type=float, default=self.cfg.repetition_penalty)
parser.add_argument("--frequency_penalty", type=float, default=self.cfg.frequency_penalty)
parser.add_argument("--presence_penalty", type=float, default=self.cfg.presence_penalty)
parser.add_argument("--max_response_tokens", type=int, default=self.cfg.max_response_tokens)
parser.add_argument("--response_chunk", type=int, default=self.cfg.response_chunk)
parser.add_argument("--no_code_formatting", type=bool, default=self.cfg.no_code_formatting)
parser.add_argument("--cache_8bit", type=bool, default=self.cfg.cache_8bit)
......@@ -103,6 +100,18 @@ class ArgumentParser:
# local chat
parser.add_argument("--prompt_file", type=str, default=self.cfg.prompt_file)
# async server
parser.add_argument("--sched_strategy", type=str, default=self.cfg.sched_strategy)
# parser.add_argument("--sched_port", type=int, default=self.cfg.sched_port)
# parser.add_argument("--sched_metrics_port", type=int, default=self.cfg.sched_metrics_port)
# parser.add_argument("--kvc2_metrics_port", type=int, default=self.cfg.kvc2_metrics_port)
parser.add_argument("--page_size", type=str, default=self.cfg.page_size)
parser.add_argument("--memory_gpu_only", type=str, default=self.cfg.memory_gpu_only)
parser.add_argument("--utilization_percentage", type=str, default=self.cfg.utilization_percentage)
parser.add_argument("--cpu_memory_size_GB", type=str, default=self.cfg.cpu_memory_size_GB)
args = parser.parse_args()
if (args.model_dir is not None or args.model_path is not None):
if (args.model_path is not None):
......@@ -123,6 +132,15 @@ class ArgumentParser:
self.cfg.mount_web = args.web
self.cfg.server_ip = args.host
self.cfg.server_port = args.port
self.cfg.backend_type = args.type
self.cfg.user_force_think = args.force_think
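        # Rough GPU KV-cache budget; the constants assume DeepSeek-V3's MLA layout: 576 compressed KV values per token per layer (512 kv_lora_rank + 64 rope dims), 61 layers, 2 bytes each for bf16.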
args.gpu_memory_size = args.cache_lens*2*576*61
self.cfg.gpu_memory_size = args.gpu_memory_size
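        # Pick three free ports (avoiding the HTTP server port) for the scheduler RPC, scheduler metrics, and kvc2 metrics endpoints.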
free_ports = get_free_ports(3, [args.port])
args.sched_port = free_ports[0]
args.sched_metrics_port = free_ports[1]
args.kvc2_metrics_port = free_ports[2]
self.cfg.sched_port = free_ports[0]
self.cfg.sched_metrics_port = free_ports[1]
self.cfg.kvc2_metrics_port = free_ports[2]
return args
......@@ -12,18 +12,10 @@ class ConfigArgs(BaseModel):
class Config:
protected_namespaces = ()
paged: bool = Field(None, description="Whether to use paged attention kv cache")
total_context: int = Field(
None,
description=(
"Total number of tokens to allocate space for. This is not the max_seq_len supported by the model but the"
" total to distribute dynamically over however many jobs are active at once"
),
)
max_batch_size: int = Field(
None, description="Max number of batches to run at once, assuming the sequences will fit within total_context"
)
chunk_prefill_size: int = Field(
chunk_size: int = Field(
None,
description=(
"Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a new"
......@@ -70,7 +62,6 @@ class ConfigArgs(BaseModel):
repetition_penalty: float = Field(None, description="Sampler repetition penalty, default = 1.01 (1 to disable)")
frequency_penalty: float = Field(None, description="Sampler frequency penalty, default = 0.0 (0 to disable)")
presence_penalty: float = Field(None, description="Sampler presence penalty, default = 0.0 (0 to disable)")
max_response_tokens: int = Field(None, description="Max tokens per response, default = 1000")
response_chunk: int = Field(None, description="Space to reserve in context for reply, default = 250")
no_code_formatting: bool = Field(None, description="Disable code formatting/syntax highlighting")
cache_8bit: bool = Field(None, description="Use 8-bit (FP8) cache")
......
......@@ -9,9 +9,11 @@ from ktransformers.server.backend.interfaces.transformers import TransformersThr
from ktransformers.server.backend.interfaces.ktransformers import KTransformersThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaThreadContext
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
class ThreadContextManager:
lock: Lock
threads_context: Dict[ObjectID, ThreadContext]
......@@ -36,7 +38,16 @@ class ThreadContextManager:
elif isinstance(self.interface, TransformersInterface):
new_context = TransformersThreadContext(run, self.interface)
else:
raise NotImplementedError
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeThreadContext
from ktransformers.server.backend.interfaces.balance_serve import BalanceServeInterface
if isinstance(self.interface, BalanceServeInterface):
new_context = BalanceServeThreadContext(run, self.interface)
else:
raise NotImplementedError
# elif isinstance(self.interface, BalanceServeInterface):
# new_context = BalanceServeThreadContext(run, self.interface)
# else:
# raise NotImplementedError
self.threads_context[run.thread_id] = new_context
# self.threads_context[run.thread_id] = ExllamaInferenceContext(run)
re = self.threads_context[run.thread_id]
......
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KDeepSeekV3Cache
from transformers import (
AutoTokenizer,
AutoConfig,
GenerationConfig,
StaticCache,
AutoModelForCausalLM,
BitsAndBytesConfig,
)
from ktransformers.server.config.config import Config
from ..base import ThreadContext, BackendInterfaceBase
import torch
from ktransformers.server.backend.interfaces.transformers import (
ConfigArgs,
default_args,
TextStreamer,
)
from ktransformers.server.schemas.base import ObjectID
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_deepseek_v2 import KDeepseekV2ForCausalLM
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.inference.forward_batch import ForwardBatchInput, ForwardBatchOutput
from ktransformers.server.balance_serve.sched_rpc import SchedulerClient
from ktransformers.server.balance_serve.settings import sched_ext
from torch.multiprocessing import Queue
import torch.multiprocessing as mp
from ktransformers.server.schemas.endpoints.chat import RawUsage
from ktransformers.server.utils.multi_timer import Profiler
import zmq
import time
import queue
import tempfile
import asyncio
import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
import os
ktransformer_rules_dir = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "./optimize/optimize_rules/")
)
default_optimize_rules = {
"DeepseekV3ForCausalLM": ktransformer_rules_dir + "DeepSeek-V3-Chat-serve.yaml",
"Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct-serve.yaml",
}
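# Stream detokenized text for a single query: tokens arrive on an asyncio.Queue and a None sentinel marks the end of generation.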
async def chat_stream(queue: asyncio.Queue, tokenizer: AutoTokenizer):
streamer = TextStreamer(tokenizer)
while True:
token = await queue.get()
#print(f"Got token: {token}")
if token is None:
# str = f'{token}\n\n'
# str = model.tokenizer.decode(token)
s = streamer.end()
if s is not None:
yield s
break
# str = model.tokenizer.decode(token)
yield streamer.put(token)
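# Write each newly sampled token into its scheduler update and, for decode-phase queries, into the query's token buffer at the active position.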
def fill_generated_tokens(query_updates: list[sched_ext.QueryUpdate], generated_tokens: torch.Tensor, query_manager: QueryManager = None):
#print(len(query_updates), generated_tokens.size(0), generated_tokens)
for i in range(generated_tokens.size(0)):
print(generated_tokens[i].item())
query_updates[i].generated_token = generated_tokens[i].item()
if not query_manager.query_map[query_updates[i].id].is_prefill:
pos = query_updates[i].active_position
query_manager.query_map[query_updates[i].id].query_tokens[pos] = generated_tokens[i]
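# Log prefill/decode throughput (tokens/s) and per-stage wall-clock time collected by the profiler.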
def report_last_time_performance(profiler: Profiler):
try:
tokenize_time = profiler.get_timer_sec('tokenize')
prefill_time = profiler.get_timer_sec('prefill')
decode_time = profiler.get_timer_sec('decode')
prefill_count = profiler.get_counter('prefill')
decode_count = profiler.get_counter('decode')
logger.info(f'Performance(T/s): prefill {prefill_count/prefill_time}, decode {decode_count/decode_time}. Time(s): tokenize {tokenize_time}, prefill {prefill_time}, decode {decode_time}')
    except Exception:
        logger.info('Performance statistics not recorded')
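# Engine runs in a spawned subprocess: it pulls batches from the scheduler, executes the model, samples next tokens, and hands them back to the server process through a multiprocessing queue.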
class Engine:
sched_client : SchedulerClient
updates : list[sched_ext.QueryUpdate]
batch : sched_ext.BatchQueryTodo
model_runner: ModelRunner
sampler: Sampler
query_manager: QueryManager
cache: KDeepSeekV3Cache
def __init__(self, args: ConfigArgs = default_args, generated_token_queue:Queue = None, broadcast_endpoint: str = None):
self.args = args
        # The child process and the parent process cannot share config variables, so copy the CLI args into the Config singleton here.
for key, value in vars(args).items():
if value is not None and hasattr(Config(), key):
setattr(Config(), key, value)
self.device = self.args.device
self.sched_client = SchedulerClient(args.sched_port)
self.updates = []
config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
self.cache = KDeepSeekV3Cache(config, self.args.page_size)
self.gen_queue = generated_token_queue
print(f"Getting inference context from sched_client.")
inference_context = self.sched_client.get_inference_context_raw()
print(f"Got inference context, sending it to subscribers.")
inference_context = self.sched_client.rebuild_inferece_context(inference_context)
self.cache.load(inference_context)
print(f"kv_cache loaded successfully.")
self.block_num = inference_context.k_cache[0].size(1)
with torch.device("meta"):
if config.architectures[0] == "DeepseekV3ForCausalLM":
self.model = KDeepseekV3ForCausalLM(config, self.cache)
elif config.architectures[0] == "DeepseekV2ForCausalLM":
self.model = KDeepseekV2ForCausalLM(config, self.cache)
# print(self.block_num)
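        # Broadcast each scheduled batch to subscriber processes over a ZeroMQ PUB socket bound to an IPC endpoint.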
context = zmq.Context()
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
# time.sleep(1) # make sure all subscribers are ready
try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
        except Exception:
generation_config = GenerationConfig(
max_length=args.max_new_tokens,
temperature=args.temperature,
top_p=args.top_p,
do_sample=True
)
if args.optimize_config_path is None:
optimize_config_path = default_optimize_rules[config.architectures[0]]
else:
optimize_config_path = args.optimize_config_path
gguf_path = args.gguf_path
if gguf_path is None:
gguf_path = input(
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
self.model.generation_config = generation_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
self.model.eval()
#@TODO add config
self.model.init_wrapper(self.args.use_cuda_graph, self.device, args.max_batch_size, self.block_num)
self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size)
self.sampler = Sampler()
self.query_manager = QueryManager(device = self.device, page_size = args.page_size)
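    # Sample one token per request from the batch logits, using per-query temperature/top_p when the forward output provides them and falling back to the model's generation_config otherwise.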
def sampling(self, forward_output: ForwardBatchOutput):
generated_tokens = torch.empty(0, device=self.device, dtype=torch.int32)
for i in range(forward_output.num_batchs):
logit = forward_output.logits[i]
if hasattr(forward_output, "temperatures"):
temperatures = forward_output.temperatures[i]
else:
temperatures = None
if hasattr(forward_output, "top_ps"):
top_ps = forward_output.top_ps[i]
else:
top_ps = None
sample_options = SamplingOptions(logit.size(0), self.device, pretrained_config=self.model.generation_config, temperatures=temperatures, top_ps=top_ps)
            generated_tokens, probs = self.sampler(logit, sample_options)
return generated_tokens, probs
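    # Main serving loop: launch the current batch on the GPU, then, while it runs, deliver last step's tokens, fetch the next batch from the scheduler, and finally sync, sample, and record updates.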
def loop(self):
next_batch = None
while True:
self.batch = next_batch
if self.batch is not None:
self.model_runner.run(self.batch, self.query_manager)
if len(self.updates) > 0:
for q in self.updates:
                    if q.is_prefill:
continue
# print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
try:
                        self.gen_queue.put((q.id, q.generated_token if not q.decode_done else None), timeout=5)
                    except queue.Full:
                        pass  # print("Queue is full after timeout; unable to put more items.")
next_batch = self.sched_client.update_last_batch(self.updates)
if next_batch.query_ids == []:
next_batch = None
self.pub_socket.send_pyobj(next_batch)
if next_batch is not None:
self.query_manager.add_query(next_batch)
if self.batch is not None:
self.model_runner.sync()
print(f"Model execution time (GPU): {self.model_runner.model_time:.3f} ms")
# if self.rank == 0:
generated_tokens, probs = self.sampling( self.model_runner.output)
self.updates = self.query_manager.update(self.batch)
fill_generated_tokens(self.updates, generated_tokens, self.query_manager)
else:
self.updates = []
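# Thread context for the balance_serve backend: flattens stored chat messages into plain role/content dicts.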
class BalanceServeThreadContext(ThreadContext):
def get_local_messages(self):
local_messages = []
for m in self.messages:
local_messages.append({"role": m.role.value, "content": m.get_text_content()})
return local_messages
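# Subprocess entry point: build the Engine, optionally warm up CUDA graphs, signal the parent that startup finished, then serve forever.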
def run_engine(args, token_queue, broadcast_endpoint, event):
engine = Engine(args, token_queue, broadcast_endpoint)
if args.use_cuda_graph:
engine.model_runner.warmup()
event.set()
engine.loop()
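# Server-side interface: spawns the Engine subprocess, registers queries with the scheduler, and proxies generated tokens from the multiprocessing queue into per-query asyncio queues.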
class BalanceServeInterface(BackendInterfaceBase):
use_static_cache: bool = True
model: Any
tokenizer: AutoTokenizer
cache: StaticCache
generated_ids: torch.Tensor
seq_length: int
streamer: TextStreamer
# thread_related
last_request_id: Optional[str] = None
ever_generated_ids: Set[int] = set()
def __init__(self, args: ConfigArgs = default_args):
self.args = args
self.queue_map:dict[int,asyncio.Queue] = {}
self.thread_map: dict[int, int] = {}
processes = []
self.broadcast_endpoint = tempfile.NamedTemporaryFile(delete=False).name # @TODO add to config
ctx = mp.get_context("spawn")
self.token_queue = ctx.Queue(maxsize=1000)
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
self.sched_client = SchedulerClient(args.sched_port)
self.streamer = TextStreamer(self.tokenizer)
start_event = ctx.Event()
p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event))
p.start()
processes.append(p)
start_event.wait()
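    # Run the token proxy on a dedicated event loop (an alternative to attaching it to the FastAPI lifespan below).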
def run_queue_proxy(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.queue_proxy())
@asynccontextmanager
async def lifespan(self, app: FastAPI):
asyncio.create_task(self.queue_proxy())
yield
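    # Forward (query_id, token) pairs from the engine's multiprocessing queue to the matching per-query asyncio queue, yielding control whenever no token is pending.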
async def queue_proxy(self):
print("Queue Proxy Started")
while True:
try:
query_id, token = self.token_queue.get_nowait()
try:
# query id might not be allocated yet
self.queue_map[query_id].put_nowait(token)
#print(f"Proxy Put token: {token} to queue for query id: {query_id}")
except asyncio.QueueFull:
#print(f"Queue for query id: {query_id} is full, waiting to put: {token}")
await self.queue_map[query_id].put(token)
except queue.Empty:
# print("no new token")
# await asyncio.sleep(1)
await asyncio.sleep(0)
def tokenize_prompt(self, prompt: str):
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.args.device)
return input_ids
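    # Rewrite the message list for chat templates that reject system roles or consecutive user turns, apply the chat template with the generation prompt, and drop a trailing '<think>\n' the template may append.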
def format_and_tokenize_input_ids(self, thread_id: ObjectID, messages: List):
for m in messages:
if m["role"] == "system":
logger.warning(f'change {m["role"]} to user')
m["role"] = "user"
new_messages = [messages[0]]
for m in messages[1:]:
if m["role"] == "user" and new_messages[-1]["role"] == "user":
logger.warning("merge two adjacent user messages")
new_messages[-1]["content"] += '\n' + m["content"]
else:
new_messages.append(m)
input_str: str = self.tokenizer.apply_chat_template(new_messages,tokenize=False,add_generation_prompt=True)
# drop <think> token in chat template
if input_str.endswith('<think>\n'):
input_str = input_str[:-len('<think>\n')]
input_ids = self.tokenizer.encode(input_str, return_tensors="pt").to(self.args.device)
logger.debug(f"get input ids of shape {input_ids.shape}")
return input_ids
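    # Async generator behind a chat request: tokenize the input, register the query with the scheduler, stream tokens as they arrive, then yield a finish reason and a final RawUsage record.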
async def inference(self, local_messages, thread_id: str, temperature: Optional[float] = None, top_p: Optional[float] = None):
profiler = Profiler()
profiler.create_and_start_timer("tokenize")
if isinstance(local_messages, List):
input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
elif isinstance(local_messages, str):
#local_messages = local_messages[0]['content']
input_ids = self.tokenize_prompt(local_messages)
else:
raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
profiler.pause_timer("tokenize")
profiler.create_and_start_timer("prefill")
query_add = sched_ext.QueryAdd()
query_add.query_token = input_ids[0].tolist()
query_length = input_ids[0].shape[0]
query_add.query_length = query_length
profiler.set_counter("prefill", query_length)
#@TODO add server
stop_criteria = [self.tokenizer.encode(self.tokenizer.eos_token, add_special_tokens=False),self.tokenizer.encode("<|im_end|>")]
query_add.stop_criteria = stop_criteria
query_add.sample_options.temperature = temperature
query_add.sample_options.top_p = top_p
query_add.estimated_length = min(self.args.cache_lens, query_length+self.args.max_new_tokens)
query_id = self.sched_client.add_query(query_add)
queue = asyncio.Queue(maxsize=self.args.max_new_tokens)
self.queue_map[query_id] = queue
self.thread_map[thread_id] = query_id
is_first_token = True
async for token in chat_stream(self.queue_map[query_id], self.tokenizer):
if is_first_token:
is_first_token=False
profiler.pause_timer("prefill")
profiler.create_and_start_timer("decode")
profiler.set_counter("decode", 0)
if Config().user_force_think:
think = '<think>\n'
print(think, end="",flush=True)
yield think, None
else:
profiler.inc("decode")
yield token, None
profiler.pause_timer("decode")
report_last_time_performance(profiler)
yield self.streamer.end(), None
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
yield "", "length"
else:
yield "", "stop"
yield RawUsage(
tokenize_time = profiler.get_timer_sec('tokenize'),
prefill_time = profiler.get_timer_sec('prefill'),
decode_time = profiler.get_timer_sec('decode'),
prefill_count = profiler.get_counter('prefill'),
decode_count = profiler.get_counter('decode'),
)