Unverified Commit 22804eaa authored by blkmjsian, committed by GitHub

[T2-3-1]blkmjsian

- deepseek
- jiuge 4B awq 
parent 5c6000ec
#ifndef INFINICORE_INFER_H
#define INFINICORE_INFER_H
#include "infinicore_infer/models/jiuge.h"
#include "infinicore_infer/cache.h"
#include "infinicore_infer/weights_loader.h"
#include "infinicore_infer/models/deepseek.h"
#include "infinicore_infer/models/jiuge.h"
#endif /* INFINICORE_INFER_H */
#ifndef CACHE_H
#define CACHE_H
#include <infinirt.h>
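/// @brief Create a model-independent KV cache: for each of the ndev devices,
/// one K tensor of shape {max_len, nkvh_ / ndev, dk} and one V tensor of
/// shape {max_len, nkvh_ / ndev, dv} per layer, in the given dtype.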
__C __export struct KVCache *createKVCache(
size_t nlayers,
size_t max_len,
size_t nkvh_,
size_t dk,
size_t dv,
infiniDtype_t dtype,
infiniDevice_t device,
int *dev_ids,
size_t ndev);
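/// @brief Duplicate a KV cache, copying the first seq_len positions of every
/// layer on every device into a freshly allocated cache.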
__C __export struct KVCache *duplicateKVCache(const struct KVCache *kv_cache, size_t seq_len);
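/// @brief Destroy a KV cache and release its device buffers.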
__C __export void dropKVCache(struct KVCache *kv_cache);
#endif /* CACHE_H */
#ifndef DEEPSEEK_V3_WEIGHTS_H
#define DEEPSEEK_V3_WEIGHTS_H
#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>
#include <stddef.h>
#include <stdint.h>
struct DeepSeekV3Weights;
// Function pointer signatures
typedef void (*load_global_fn)(DeepSeekV3Weights *, void *cpu_ptr);
typedef void (*load_layer_fn)(DeepSeekV3Weights *, void *cpu_ptr, size_t layer_id);
typedef void (*load_layer_linear_fn)(DeepSeekV3Weights *, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id);
typedef void (*load_layer_mlp_fn)(
DeepSeekV3Weights *,
void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
size_t layer_id);
typedef void (*load_layer_expert_mlp_fn)(
DeepSeekV3Weights *,
void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
size_t layer_id, size_t expert_id);
// Struct containing all weight loading functions
typedef struct {
// Global
load_global_fn load_input_embd;
load_global_fn load_output_norm;
load_global_fn load_output_embd;
// Attention
load_layer_fn load_attn_norm;
load_layer_linear_fn load_attn_q_a_proj;
load_layer_fn load_attn_q_a_layernorm;
load_layer_linear_fn load_attn_q_b_proj;
load_layer_linear_fn load_attn_kv_a_proj_with_mqa;
load_layer_fn load_attn_kv_a_layernorm;
load_layer_linear_fn load_attn_kv_b_proj;
load_layer_linear_fn load_attn_o_proj;
// MLP
load_layer_fn load_mlp_norm;
// MLP dense part
load_layer_mlp_fn load_mlp_dense;
// MLP sparse gating
load_layer_fn load_mlp_gate_weight;
load_layer_fn load_mlp_gate_bias;
// Shared experts
load_layer_mlp_fn load_mlp_shared_experts;
// Per-expert functions
load_layer_expert_mlp_fn load_mlp_experts;
} DeepSeekV3WeightLoader;
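// Each linear/MLP loader consumes an AWQ-quantized projection as a
// (qweight, scales, qzeros) triple of host pointers, matching the
// .qweight/.scales/.qzeros tensors of the checkpoint.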
struct DeepSeekV3Model;
typedef struct {
infiniDtype_t dt_logits;
infiniDtype_t dt_norm;
infiniDtype_t dt_quant_weight;
infiniDtype_t dt_quant_scale;
infiniDtype_t dt_quant_zero;
infiniDtype_t dt_gate_weight;
infiniDtype_t dt_gate_bias;
size_t n_sparse_layer;
size_t n_dense_layer;
size_t d;
size_t nh;
size_t nkvh;
size_t d_rope;
size_t d_nope;
size_t r_q;
size_t r_kv;
size_t d_qk;
size_t d_v;
float routed_scale;
size_t nexperts;
size_t kexperts;
size_t di;
size_t di_moe;
size_t dctx;
size_t dvoc;
float epsilon;
float rope_theta;
uint32_t end_token;
} DeepSeekV3Meta;
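// Note: d_qk = d_nope + d_rope (per-head query/key dim after concatenating the
// NoPE and RoPE parts of MLA); r_q and r_kv are the MLA low-rank projection
// ranks (q_lora_rank / kv_lora_rank in the HuggingFace config).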
//////////////////// APIs ///////////////////////
/// @brief Create the model from its metadata and loaded weights
__C __export struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *,
const DeepSeekV3Weights *);
/// @brief Create the weights container
/// @param device accelerator type
/// @param ndev number of accelerator devices
/// @param dev_ids accelerator device IDs, array of length ndev
__C __export DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device,
int ndev,
const int *dev_ids);
__C __export DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader();
/// @brief Destroy the model
__C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *);
__C __export struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *);
__C __export void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
struct DeepSeekV3Cache *);
/// @brief Run one round of batched inference and sample one new token per request
/// @param tokens pointer to the input tokens
/// @param ntok total number of input tokens
/// @param nreq number of requests
/// @param req_lens number of tokens in each request
/// @param req_pos start position of each request in its KV cache
/// @param caches KV cache of each request
/// @param temperature per-request sampling temperature (0. means greedy sampling)
/// @param topk per-request sampling top-k (1 means greedy sampling)
/// @param topp per-request sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
inferBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct DeepSeekV3Cache **caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
/// @brief Run one round of batched inference and write the logits produced after the output embedding
/// @param tokens pointer to the input tokens
/// @param ntok total number of input tokens
/// @param nreq number of requests
/// @param req_lens number of tokens in each request
/// @param req_pos start position of each request in its KV cache
/// @param caches KV cache of each request
/// @param logits output buffer, one row of dvoc logits per input token (ntok rows)
__C __export void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct DeepSeekV3Cache **caches,
void *logits);
#endif // DEEPSEEK_V3_WEIGHTS_H
......@@ -61,20 +61,6 @@ createJiugeModel(const JiugeMeta *,
__C __export void
destroyJiugeModel(struct JiugeModel *);
/// @brief Create a KV cache
__C __export struct KVCache *
createKVCache(const struct JiugeModel *);
/// @brief Duplicate a KV cache
__C __export struct KVCache *
duplicateKVCache(const struct JiugeModel *,
const struct KVCache *, uint32_t seq_len);
/// @brief Destroy a KV cache
__C __export void
dropKVCache(const struct JiugeModel *,
struct KVCache *);
/// @brief Run one round of batched inference and sample one new token per request
/// @param tokens pointer to the input tokens
/// @param ntok total number of input tokens
......@@ -87,12 +73,12 @@ dropKVCache(const struct JiugeModel *,
/// @param topp per-request sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
inferBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
/// @brief Run one round of batched inference and write the logits produced after the output embedding
/// @param tokens pointer to the input tokens
......@@ -103,10 +89,10 @@ inferBatch(struct JiugeModel *,
/// @param kv_caches KV cache of each request
/// @param logits output buffer, one row of dvoc logits per input token (ntok rows)
__C __export void
forwardBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
forwardBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
#endif
#ifndef MODEL_JIUGE_AWQ_H
#define MODEL_JIUGE_AWQ_H
#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>
#include <stdint.h>
#include "../weights_loader.h"
struct JiugeAWQModel;
typedef struct
{
infiniDtype_t dt_logits;
infiniDtype_t dt_linear_w;
infiniDtype_t dt_norm_w;
size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
float epsilon, theta;
uint32_t end_token;
size_t nbit;
size_t quant_group_size;
char has_qkv_bias;
} JiugeAWQMeta;
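// nbit and quant_group_size mirror "bits" and "group_size" in the checkpoint's
// quantization_config; has_qkv_bias is nonzero when the model uses Q/K/V
// projection biases (e.g. qwen2/qwen3-style configs).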
//////////////////// APIs ///////////////////////
/// @brief Create the weights container
/// @param device accelerator type
/// @param ndev number of accelerator devices
/// @param dev_ids accelerator device IDs, array of length ndev
__C __export struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *,
infiniDevice_t device,
int ndev,
const int *dev_ids);
/// @brief Create the model from its metadata and loaded weights
__C __export struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *,
const ModelWeights *);
/// @brief Destroy the model
__C __export void
destroyJiugeAWQModel(struct JiugeAWQModel *);
/// @brief Run one round of batched inference and sample one new token per request
/// @param tokens pointer to the input tokens
/// @param ntok total number of input tokens
/// @param nreq number of requests
/// @param req_lens number of tokens in each request
/// @param req_pos start position of each request in its KV cache
/// @param kv_caches KV cache of each request
/// @param temperature per-request sampling temperature (0. means greedy sampling)
/// @param topk per-request sampling top-k (1 means greedy sampling)
/// @param topp per-request sampling top-p
/// @param output output token array, one token per request, length at least nreq
__C __export void
inferBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
/// @brief Run one round of batched inference and write the logits produced after the output embedding
/// @param tokens pointer to the input tokens
/// @param ntok total number of input tokens
/// @param nreq number of requests
/// @param req_lens number of tokens in each request
/// @param req_pos start position of each request in its KV cache
/// @param kv_caches KV cache of each request
/// @param logits output buffer, one row of dvoc logits per input token (ntok rows)
__C __export void
forwardBatchJiugeAWQ(struct JiugeAWQModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
#endif
#ifndef WEIGHTS_LOADER_H
#define WEIGHTS_LOADER_H
#include <infinirt.h>
struct ModelWeights;
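/// @brief Copy the named weight tensor from host memory into the model's
/// device weights.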
__C __export void
loadModelWeight(struct ModelWeights *weights, const char *name, void *data);
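/// @brief Same as loadModelWeight, but restricted to the given device ranks
/// (used for tensor-parallel weight shards).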
__C __export void
loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank);
#endif // WEIGHTS_LOADER_H
import ctypes
from typing import List, Sequence
from tqdm import tqdm
from libinfinicore_infer import (
DeepSeekV3MetaCStruct,
DeepSeekV3CacheCStruct,
DataType,
DeviceType,
create_deepseek_v3_model,
create_deepseek_v3_weights,
create_deepseek_v3_weight_loader,
destroy_deepseek_v3_model,
create_deepseek_v3_cache,
drop_deepseek_v3_cache,
infer_batch_deepseek_v3,
)
from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
torch.set_default_device("cpu")
class DeepseekR1WeightsNaming:
def __init__(self, dense_replace=3):
self.dense_replace = dense_replace
def input_embd(self):
return "model.embed_tokens.weight"
def output_norm(self):
return "model.norm.weight"
def output_embd(self):
return "lm_head.weight"
# MLA
def attn_norm(self, i):
return f"model.layers.{i}.input_layernorm.weight"
def attn_kv_a_layernorm(self, i):
return f"model.layers.{i}.self_attn.kv_a_layernorm.weight"
def attn_kv_a_proj_with_mqa_weight(self, i):
return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.qweight"
def attn_kv_a_proj_with_mqa_scale(self, i):
return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.scales"
def attn_kv_a_proj_with_mqa_zero(self, i):
return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.qzeros"
def attn_kv_b_proj_weight(self, i):
return f"model.layers.{i}.self_attn.kv_b_proj.qweight"
def attn_kv_b_proj_scale(self, i):
return f"model.layers.{i}.self_attn.kv_b_proj.scales"
def attn_kv_b_proj_zero(self, i):
return f"model.layers.{i}.self_attn.kv_b_proj.qzeros"
def attn_o_proj_weight(self, i):
return f"model.layers.{i}.self_attn.o_proj.qweight"
def attn_o_proj_scale(self, i):
return f"model.layers.{i}.self_attn.o_proj.scales"
def attn_o_proj_zero(self, i):
return f"model.layers.{i}.self_attn.o_proj.qzeros"
def attn_q_a_layernorm(self, i):
return f"model.layers.{i}.self_attn.q_a_layernorm.weight"
def attn_q_a_proj_weight(self, i):
return f"model.layers.{i}.self_attn.q_a_proj.qweight"
def attn_q_a_proj_scale(self, i):
return f"model.layers.{i}.self_attn.q_a_proj.scales"
def attn_q_a_proj_zero(self, i):
return f"model.layers.{i}.self_attn.q_a_proj.qzeros"
def attn_q_b_proj_weight(self, i):
return f"model.layers.{i}.self_attn.q_b_proj.qweight"
def attn_q_b_proj_scale(self, i):
return f"model.layers.{i}.self_attn.q_b_proj.scales"
def attn_q_b_proj_zero(self, i):
return f"model.layers.{i}.self_attn.q_b_proj.qzeros"
# MLP
def mlp_norm(self, i):
return f"model.layers.{i}.post_attention_layernorm.weight"
# First self.dense_replace layers are dense
def mlp_down_proj_weight(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.down_proj.qweight"
def mlp_down_proj_scale(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.down_proj.scales"
def mlp_down_proj_zero(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.down_proj.qzeros"
def mlp_up_proj_weight(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.up_proj.qweight"
def mlp_up_proj_scale(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.up_proj.scales"
def mlp_up_proj_zero(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.up_proj.qzeros"
def mlp_gate_proj_weight(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.gate_proj.qweight"
def mlp_gate_proj_scale(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.gate_proj.scales"
def mlp_gate_proj_zero(self, i):
assert i < self.dense_replace
return f"model.layers.{i}.mlp.gate_proj.qzeros"
# Latter layers are sparse
# Gating
def mlp_gate_weight(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.gate.weight"
def mlp_gate_bias(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.gate.e_score_correction_bias"
# Experts
def mlp_shared_experts_down_proj_weight(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.down_proj.qweight"
def mlp_shared_experts_down_proj_scale(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.down_proj.scales"
def mlp_shared_experts_down_proj_zero(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.down_proj.qzeros"
def mlp_shared_experts_gate_proj_weight(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.gate_proj.qweight"
def mlp_shared_experts_gate_proj_scale(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.gate_proj.scales"
def mlp_shared_experts_gate_proj_zero(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.gate_proj.qzeros"
def mlp_shared_experts_up_proj_weight(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.up_proj.qweight"
def mlp_shared_experts_up_proj_scale(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.up_proj.scales"
def mlp_shared_experts_up_proj_zero(self, i):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.shared_experts.up_proj.qzeros"
# Experts
def mlp_experts_down_proj_weight(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.down_proj.qweight"
def mlp_experts_down_proj_scale(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.down_proj.scales"
def mlp_experts_down_proj_zero(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.down_proj.qzeros"
def mlp_experts_gate_proj_weight(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.gate_proj.qweight"
def mlp_experts_gate_proj_scale(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.gate_proj.scales"
def mlp_experts_gate_proj_zero(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.gate_proj.qzeros"
def mlp_experts_up_proj_weight(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.up_proj.qweight"
def mlp_experts_up_proj_scale(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.up_proj.scales"
def mlp_experts_up_proj_zero(self, i, e):
assert i >= self.dense_replace
return f"model.layers.{i}.mlp.experts.{e}.up_proj.qzeros"
class DeepSeekV3Meta(DeepSeekV3MetaCStruct):
def __init__(self, config, dtype=torch.float16, max_tokens=None):
if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16
elif dtype == torch.bfloat16:
dt_ = DataType.INFINI_DTYPE_BF16
else:
dt_ = DataType.INFINI_DTYPE_F16
super().__init__(
# dtypes
dt_logits=dt_,
dt_norm=DataType.INFINI_DTYPE_BF16,
dt_quant_weight=DataType.INFINI_DTYPE_I32,
dt_quant_scale=DataType.INFINI_DTYPE_F16,
dt_quant_zero=DataType.INFINI_DTYPE_I32,
dt_gate_weight=DataType.INFINI_DTYPE_BF16,
dt_gate_bias=DataType.INFINI_DTYPE_BF16,
# sizes
n_sparse_layer=config["num_hidden_layers"],
n_dense_layer=config.get("first_k_dense_replace", 0),
d=config["hidden_size"],
nh=config["num_attention_heads"],
nkvh=config.get("num_key_value_heads", config["num_attention_heads"]),
d_rope=config["qk_rope_head_dim"],
d_nope=config["qk_nope_head_dim"],
r_q=config["q_lora_rank"],
r_kv=config["kv_lora_rank"],
d_qk=config["qk_nope_head_dim"] + config["qk_rope_head_dim"],
d_v=config["v_head_dim"],
# routing / experts / vocab / ctx
routed_scale=config.get("routed_scaling_factor", 1.0),
nexperts=config["n_routed_experts"],
kexperts=config["num_experts_per_tok"],
di=config["intermediate_size"],
di_moe=config["moe_intermediate_size"],
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"],
# misc
epsilon=config.get("rms_norm_eps", 1e-6),
rope_theta=config.get("rope_theta", 10000.0),
end_token=config.get("eos_token_id", 2),
)
self.torch_dtype_logits = dtype
def load_specific_tensor(model_dir, tensor_name):
"""
Load a specific tensor from a sharded safetensors model using its index JSON.
"""
index_file = os.path.join(model_dir, "model.safetensors.index.json")
if not os.path.exists(index_file):
raise FileNotFoundError(f"Index file not found: {index_file}")
with open(index_file, "r") as f:
index = json.load(f)
# Get mapping: tensor name -> file name
weight_map = index["weight_map"]
if tensor_name not in weight_map:
raise KeyError(f"{tensor_name} not found in index")
filename = weight_map[tensor_name]
tensor_file = os.path.join(model_dir, filename)
# Open only the relevant file and tensor
with safetensors.safe_open(tensor_file, framework="pt", device="cpu") as f:
tensor = f.get_tensor(tensor_name)
return tensor
def load_deepseek_weights(
meta: DeepSeekV3Meta,
weights,
model_path: str,
ndev: int,
):
weight_loader = create_deepseek_v3_weight_loader()
names = DeepseekR1WeightsNaming()
input_embd = load_specific_tensor(model_path, names.input_embd()).to(meta.torch_dtype_logits)
weight_loader.contents.load_input_embd(weights, input_embd.data_ptr())
del input_embd
output_norm = load_specific_tensor(model_path, names.output_norm())
weight_loader.contents.load_output_norm(weights, output_norm.data_ptr())
del output_norm
output_embd = load_specific_tensor(model_path, names.output_embd())
weight_loader.contents.load_output_embd(weights, output_embd.data_ptr())
del output_embd
# -------------------------------
# Per-layer weights
# -------------------------------
def load_quant(w_name, s_name, zero_name, split_dim=0):
weight = load_specific_tensor(model_path, w_name)
scale = load_specific_tensor(model_path, s_name)
zero = load_specific_tensor(model_path, zero_name)
if split_dim == 0 or ndev == 1:
return weight, scale, zero
elif split_dim == 1:
weight = (
weight.reshape(weight.shape[0], ndev, -1).permute(1, 0, 2).contiguous()
)
scale = (
scale.reshape(scale.shape[0], ndev, -1).permute(1, 0, 2).contiguous()
)
zero = zero.reshape(zero.shape[0], ndev, -1).permute(1, 0, 2).contiguous()
return weight, scale, zero
else:
raise ValueError("split_dim must be 0 or 1")
for i in tqdm(
range(meta.n_sparse_layer + meta.n_dense_layer), desc="Loading layers"
):
# Attention norms + projections
attn_norm = load_specific_tensor(model_path, names.attn_norm(i))
weight_loader.contents.load_attn_norm(weights, attn_norm.data_ptr(), i)
del attn_norm
load_attn_q_a_layernorm = load_specific_tensor(
model_path, names.attn_q_a_layernorm(i)
)
weight_loader.contents.load_attn_q_a_layernorm(
weights, load_attn_q_a_layernorm.data_ptr(), i
)
del load_attn_q_a_layernorm
attn_kv_a_layernorm = load_specific_tensor(
model_path, names.attn_kv_a_layernorm(i)
)
weight_loader.contents.load_attn_kv_a_layernorm(
weights, attn_kv_a_layernorm.data_ptr(), i
)
del attn_kv_a_layernorm
w, s, z = load_quant(
names.attn_q_a_proj_weight(i),
names.attn_q_a_proj_scale(i),
names.attn_q_a_proj_zero(i),
)
weight_loader.contents.load_attn_q_a_proj(
weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i
)
w, s, z = load_quant(
names.attn_q_b_proj_weight(i),
names.attn_q_b_proj_scale(i),
names.attn_q_b_proj_zero(i),
)
weight_loader.contents.load_attn_q_b_proj(
weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i
)
w, s, z = load_quant(
names.attn_kv_a_proj_with_mqa_weight(i),
names.attn_kv_a_proj_with_mqa_scale(i),
names.attn_kv_a_proj_with_mqa_zero(i),
)
weight_loader.contents.load_attn_kv_a_proj_with_mqa(
weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i
)
w, s, z = load_quant(
names.attn_kv_b_proj_weight(i),
names.attn_kv_b_proj_scale(i),
names.attn_kv_b_proj_zero(i),
)
weight_loader.contents.load_attn_kv_b_proj(
weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i
)
w, s, z = load_quant(
names.attn_o_proj_weight(i),
names.attn_o_proj_scale(i),
names.attn_o_proj_zero(i),
1,
)
weight_loader.contents.load_attn_o_proj(
weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i
)
# -------------------------------
# MLP: dense or sparse
# -------------------------------
mlp_norm = load_specific_tensor(model_path, names.mlp_norm(i))
weight_loader.contents.load_mlp_norm(weights, mlp_norm.data_ptr(), i)
if i < meta.n_dense_layer:
# Dense MLP is grouped into one call
w_gate, s_gate, z_gate = load_quant(
names.mlp_gate_proj_weight(i),
names.mlp_gate_proj_scale(i),
names.mlp_gate_proj_zero(i),
)
w_up, s_up, z_up = load_quant(
names.mlp_up_proj_weight(i),
names.mlp_up_proj_scale(i),
names.mlp_up_proj_zero(i),
)
w_down, s_down, z_down = load_quant(
names.mlp_down_proj_weight(i),
names.mlp_down_proj_scale(i),
names.mlp_down_proj_zero(i),
1,
)
weight_loader.contents.load_mlp_dense(
weights,
w_gate.data_ptr(),
s_gate.data_ptr(),
z_gate.data_ptr(),
w_up.data_ptr(),
s_up.data_ptr(),
z_up.data_ptr(),
w_down.data_ptr(),
s_down.data_ptr(),
z_down.data_ptr(),
i,
)
else:
# Sparse MLP gating
mlp_gate_weight = load_specific_tensor(model_path, names.mlp_gate_weight(i))
weight_loader.contents.load_mlp_gate_weight(
weights, mlp_gate_weight.data_ptr(), i
)
del mlp_gate_weight
mlp_gate_bias = load_specific_tensor(model_path, names.mlp_gate_bias(i))
weight_loader.contents.load_mlp_gate_bias(
weights, mlp_gate_bias.data_ptr(), i
)
del mlp_gate_bias
# Shared experts
w_gate, s_gate, z_gate = load_quant(
names.mlp_shared_experts_gate_proj_weight(i),
names.mlp_shared_experts_gate_proj_scale(i),
names.mlp_shared_experts_gate_proj_zero(i),
)
w_up, s_up, z_up = load_quant(
names.mlp_shared_experts_up_proj_weight(i),
names.mlp_shared_experts_up_proj_scale(i),
names.mlp_shared_experts_up_proj_zero(i),
)
w_down, s_down, z_down = load_quant(
names.mlp_shared_experts_down_proj_weight(i),
names.mlp_shared_experts_down_proj_scale(i),
names.mlp_shared_experts_down_proj_zero(i),
1,
)
weight_loader.contents.load_mlp_shared_experts(
weights,
w_gate.data_ptr(),
s_gate.data_ptr(),
z_gate.data_ptr(),
w_up.data_ptr(),
s_up.data_ptr(),
z_up.data_ptr(),
w_down.data_ptr(),
s_down.data_ptr(),
z_down.data_ptr(),
i,
)
# Per-expert MLP
for e in range(meta.nexperts):
w_gate, s_gate, z_gate = load_quant(
names.mlp_experts_gate_proj_weight(i, e),
names.mlp_experts_gate_proj_scale(i, e),
names.mlp_experts_gate_proj_zero(i, e),
)
w_up, s_up, z_up = load_quant(
names.mlp_experts_up_proj_weight(i, e),
names.mlp_experts_up_proj_scale(i, e),
names.mlp_experts_up_proj_zero(i, e),
)
w_down, s_down, z_down = load_quant(
names.mlp_experts_down_proj_weight(i, e),
names.mlp_experts_down_proj_scale(i, e),
names.mlp_experts_down_proj_zero(i, e),
1,
)
weight_loader.contents.load_mlp_experts(
weights,
w_gate.data_ptr(),
s_gate.data_ptr(),
z_gate.data_ptr(),
w_up.data_ptr(),
s_up.data_ptr(),
z_up.data_ptr(),
w_down.data_ptr(),
s_down.data_ptr(),
z_down.data_ptr(),
i,
e,
)
class DeepSeekV3BatchedTask:
def __init__(self, tasks: List[InferTask]):
self.tasks = tasks
self.nreq = len(tasks)
# Precompute fields
token_lists = [t.tokens for t in tasks]
self.req_lens_list = [len(toks) for toks in token_lists]
self.req_pos_list = [t.pos for t in tasks]
self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
self.temperatures_list = [t.temperature for t in tasks]
self.topks_list = [t.topk for t in tasks]
self.topps_list = [t.topp for t in tasks]
# Flatten token lists
flat_tokens = [tok for toks in token_lists for tok in toks]
self.ntok = len(flat_tokens)
# Convert to ctypes arrays in one pass
self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
self.kv_caches = (POINTER(DeepSeekV3CacheCStruct) * self.nreq)(
*self.kv_cache_ptrs
)
self.temperatures = (c_float * self.nreq)(*self.temperatures_list)
self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list)
def input_args(self):
return (
self.tokens,
self.ntok,
self.req_lens,
self.nreq,
self.req_pos,
self.kv_caches,
self.temperatures,
self.topks,
self.topps,
)
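# Example of the flattened layout (hypothetical values): two requests with
# tokens [1, 2, 3] and [4, 5] yield tokens=[1, 2, 3, 4, 5], ntok=5,
# req_lens=[3, 2], and req_pos taken from each task's current cache position.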
class DeepSeekV3ForCausalLM:
def __init__(
self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
):
with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f)
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = (
[eos_token_id] if type(eos_token_id) == int else eos_token_id
)
print(model_dir_path)
if "deepseek_v3" == config["model_type"]:
self.meta = DeepSeekV3Meta(config, max_tokens=max_tokens, dtype=torch.float16)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
else:
raise ValueError("Unsupported model architecture")
print(f"Creating model on {ndev} devices...")
load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
weights = create_deepseek_v3_weights(
self.meta,
device,
ndev,
dev_ids,
)
# Load weights from host
load_deepseek_weights(self.meta, weights, model_dir_path, ndev)
# Create model instance
self.model_instance = create_deepseek_v3_model(
byref(self.meta),
weights,
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def max_context_len(self):
return self.meta.dctx
def create_kv_cache(self):
return create_deepseek_v3_cache(self.model_instance)
def drop_kv_cache(self, kv_cache):
drop_deepseek_v3_cache(self.model_instance, kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))()
batch_inputs = DeepSeekV3BatchedTask(tasks)
infer_batch_deepseek_v3(
self.model_instance,
*(batch_inputs.input_args()),
output,
)
return list(output)
def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True,
tokenize=False,
)
tokens = self.tokenizer.encode(input_content)
infer_task = InferTask(
0,
tokens,
self.max_context_len(),
temperature_,
topk_,
topp_,
self.eos_token_id,
)
infer_task.bind_kvcache(KVCache(self))
print(input_content, end="", flush=True)
steps = 0
total_time = 0
output_content = ""
for step_i in range(max_steps):
start_time = time.time()
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
output_str = (
self.tokenizer._tokenizer.id_to_token(output_tokens[0])
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
break
infer_task.next(output_tokens[0])
if step_i > 0:
total_time += end_time - start_time
print("\n")
avg_time = total_time * 1000 / max(steps - 1, 1)
print(f"Time per step: {avg_time:.3f}ms")
infer_task._kv_cache.drop(self)
return output_content, avg_time
# def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
# tasks = [
# InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
# for i in range(batch_size)
# ]
# kv_caches = [KVCache(self) for _ in range(batch_size)]
# nll = 0.0
# total_len = 0
# for i in range(0, len(test_sequences), batch_size):
# batch_id = 0
# true_tokens = []
# while batch_id < batch_size and batch_id + i < len(test_sequences):
# input_tokens = test_sequences[i + batch_id][:-1]
# true_tokens.extend(test_sequences[i + batch_id][1:])
# tasks[batch_id].tokens = input_tokens
# tasks[batch_id].bind_kvcache(kv_caches[batch_id])
# batch_id += 1
# batch_inputs = DeepSeekV3BatchedTask(tasks[:batch_id])
# logits = torch.zeros(
# (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
# )
# forward_batch_deepseek_v3(
# self.model_instance,
# batch_inputs.tokens,
# batch_inputs.ntok,
# batch_inputs.req_lens,
# batch_inputs.nreq,
# batch_inputs.req_pos,
# batch_inputs.kv_caches,
# logits.data_ptr(),
# )
# logits = logits.float()
# token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,]
# log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab)
# token_logprobs = log_probs[
# torch.arange(batch_inputs.ntok), token_ids
# ] # (ntok,)
# start = 0
# for l in batch_inputs.req_lens_list:
# nll += -token_logprobs[start : start + l].sum().item()
# start += l
# total_len += token_logprobs.numel()
# for task in tasks:
# task.release_kvcache()
# return math.exp(nll / total_len)
def destroy_model_instance(self):
destroy_deepseek_v3_model(self.model_instance)
print("Model destroyed")
def test():
if len(sys.argv) < 3:
print(
"Usage: python deepseek.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
model_path = sys.argv[2]
device_type = DeviceType.DEVICE_TYPE_CPU
if sys.argv[1] == "--cpu":
device_type = DeviceType.DEVICE_TYPE_CPU
elif sys.argv[1] == "--nvidia":
device_type = DeviceType.DEVICE_TYPE_NVIDIA
elif sys.argv[1] == "--cambricon":
device_type = DeviceType.DEVICE_TYPE_CAMBRICON
elif sys.argv[1] == "--ascend":
device_type = DeviceType.DEVICE_TYPE_ASCEND
elif sys.argv[1] == "--metax":
device_type = DeviceType.DEVICE_TYPE_METAX
elif sys.argv[1] == "--moore":
device_type = DeviceType.DEVICE_TYPE_MOORE
elif sys.argv[1] == "--iluvatar":
device_type = DeviceType.DEVICE_TYPE_ILUVATAR
else:
print(
"Usage: python deepseek.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = DeepSeekV3ForCausalLM(model_path, device_type, ndev, max_tokens=1024)
model.generate("山东最高的山是?", 50)
model.destroy_model_instance()
if __name__ == "__main__":
test()
......@@ -11,8 +11,8 @@ from libinfinicore_infer import (
destroy_jiuge_model,
create_kv_cache,
drop_kv_cache,
infer_batch,
forward_batch,
infer_batch_jiuge,
forward_batch_jiuge,
)
from infer_task import InferTask, KVCache
......@@ -506,13 +506,15 @@ class JiugeForCauslLM:
print(f"Creating model on {ndev} devices...")
load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.ndev = ndev
self.device = device
self.model_instance = create_jiuge_model(
byref(self.meta),
byref(self.weights),
device,
ndev,
dev_ids,
self.dev_ids,
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
......@@ -521,15 +523,25 @@ class JiugeForCauslLM:
return self.meta.dctx
def create_kv_cache(self):
return create_kv_cache(self.model_instance)
return create_kv_cache(
self.meta.nlayer,
self.meta.dctx,
self.meta.nkvh,
self.meta.dh,
self.meta.dh,
self.meta.dt_logits,
self.device,
self.dev_ids,
self.ndev,
)
def drop_kv_cache(self, kv_cache):
drop_kv_cache(self.model_instance, kv_cache)
drop_kv_cache(kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))()
batch_inputs = JiugeBatchedTask(tasks)
infer_batch(
infer_batch_jiuge(
self.model_instance,
*(batch_inputs.input_args()),
output,
......@@ -609,7 +621,7 @@ class JiugeForCauslLM:
logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
)
forward_batch(
forward_batch_jiuge(
self.model_instance,
batch_inputs.tokens,
batch_inputs.ntok,
......
from typing import List, Sequence
from libinfinicore_infer import (
JiugeAWQMetaCStruct,
KVCacheCStruct,
DataType,
DeviceType,
load_model_weight,
load_model_weight_distributed,
create_jiuge_awq_weights,
create_jiuge_awq_model,
destroy_jiuge_awq_model,
create_kv_cache,
drop_kv_cache,
infer_batch_jiuge_awq,
forward_batch_jiuge_awq,
)
from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
torch.set_default_device("cpu")
class JiugeAWQMetaFromConfig(JiugeAWQMetaCStruct):
def __init__(self, config, dtype=torch.float16, max_tokens=None):
if config["torch_dtype"] == "float16":
dt_ = DataType.INFINI_DTYPE_F16
elif config["torch_dtype"] == "float32":
dt_ = DataType.INFINI_DTYPE_F32
elif config["torch_dtype"] == "bfloat16":
dt_ = DataType.INFINI_DTYPE_BF16
else:
dt_ = DataType.INFINI_DTYPE_F16
self.scale_input = 1.0
self.scale_output = 1.0
self.scale_o = 1.0
self.scale_down = 1.0
if (
config["model_type"] in ["fm9g", "minicpm"]
and "scale_emb" in config
and "scale_depth" in config
and "dim_model_base" in config
):
self.scale_input = config["scale_emb"]
self.scale_output = config["hidden_size"] // config["dim_model_base"]
self.scale_o = config["scale_depth"] / math.sqrt(
config["num_hidden_layers"]
)
self.scale_down = config["scale_depth"] / math.sqrt(
config["num_hidden_layers"]
)
has_qkv_bias = (
1 if "attention_bias" in config and config["attention_bias"] else 0
)
if config["model_type"] in ["qwen2", "qwen3"]:
has_qkv_bias = 1
eos_token_id = (
config["eos_token_id"][0]
if type(config["eos_token_id"]) == list
else config["eos_token_id"]
)
super().__init__(
dt_logits=dt_,
dt_linear_w=DataType.INFINI_DTYPE_I32,
dt_norm_w=dt_,
nlayer=config["num_hidden_layers"],
d=config["hidden_size"],
nh=config["num_attention_heads"],
nkvh=(
config["num_key_value_heads"]
if "num_key_value_heads" in config
else config["num_attention_heads"]
),
dh=config["hidden_size"] // config["num_attention_heads"],
di=config["intermediate_size"],
dctx=(
config["max_position_embeddings"] if max_tokens is None else max_tokens
),
dvoc=config["vocab_size"],
epsilon=config["rms_norm_eps"],
theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
end_token=eos_token_id,
nbit=config["quantization_config"]["bits"],
quant_group_size=config["quantization_config"]["group_size"],
has_qkv_bias=has_qkv_bias,
)
if dt_ == DataType.INFINI_DTYPE_F32:
self.torch_dtype_logits = torch.float32
elif dt_ == DataType.INFINI_DTYPE_BF16:
self.torch_dtype_logits = torch.bfloat16
else:
self.torch_dtype_logits = torch.float16
class JiugeAWQBatchedTask:
def __init__(self, tasks: List[InferTask]):
self.tasks = tasks
self.nreq = len(tasks)
# Precompute fields
token_lists = [t.tokens for t in tasks]
self.req_lens_list = [len(toks) for toks in token_lists]
self.req_pos_list = [t.pos for t in tasks]
self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
self.temperatures_list = [t.temperature for t in tasks]
self.topks_list = [t.topk for t in tasks]
self.topps_list = [t.topp for t in tasks]
# Flatten token lists
flat_tokens = [tok for toks in token_lists for tok in toks]
self.ntok = len(flat_tokens)
# Convert to ctypes arrays in one pass
self.tokens = (c_uint * self.ntok)(*flat_tokens)
self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
self.temperatures = (c_float * self.nreq)(*self.temperatures_list)
self.topks = (c_uint * self.nreq)(*self.topks_list)
self.topps = (c_float * self.nreq)(*self.topps_list)
def input_args(self):
return (
self.tokens,
self.ntok,
self.req_lens,
self.nreq,
self.req_pos,
self.kv_caches,
self.temperatures,
self.topks,
self.topps,
)
class JiugeAWQForCausalLM:
def __init__(
self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
):
load_start_time = time.time()
print(f"Creating model on {ndev} devices...")
with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f)
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = (
[eos_token_id] if type(eos_token_id) == int else eos_token_id
)
self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.ndev = ndev
self.device = device
self.meta = JiugeAWQMetaFromConfig(config, max_tokens=max_tokens)
self.weights = create_jiuge_awq_weights(
self.meta,
self.device,
ndev,
self.dev_ids,
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
load_start_time = time.time()
print("Loading model weights to host...")
self.load_all_safetensors_from_dir(model_dir_path)
self.model_instance = create_jiuge_awq_model(
self.meta,
self.weights,
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
def load_all_safetensors_from_dir(self, dir_path_: str):
dir_path_ = Path(dir_path_)
for file in sorted(dir_path_.glob("*.safetensors")):
with safetensors.safe_open(file, framework="pt", device="cpu") as f:
for key in f.keys():
# print(key)
tensor = f.get_tensor(key)
if "proj" in key and "bias" not in key:
if "o_proj" in key or "down_proj" in key:
tensor = (
tensor.reshape(tensor.shape[0], self.ndev, -1)
.permute(1, 0, 2)
.contiguous()
)
if "o_proj.scales" in key:
tensor = tensor * self.meta.scale_o
elif "down_proj.scales" in key:
tensor = tensor * self.meta.scale_down
load_model_weight_distributed(
self.weights,
key,
tensor.data_ptr(),
self.dev_ids,
self.ndev,
)
else:
load_model_weight_distributed(
self.weights,
key,
tensor.data_ptr(),
self.dev_ids,
self.ndev,
)
else:
if "embed_tokens.weight" in key:
tensor = tensor * self.meta.scale_input
elif "lm_head.weight" in key:
tensor = tensor * self.meta.scale_output
load_model_weight(self.weights, key, tensor.data_ptr())
def max_context_len(self):
return self.meta.dctx
def create_kv_cache(self):
return create_kv_cache(
self.meta.nlayer,
self.meta.dctx,
self.meta.nkvh,
self.meta.dh,
self.meta.dh,
self.meta.dt_logits,
self.device,
self.dev_ids,
self.ndev,
)
def drop_kv_cache(self, kv_cache):
drop_kv_cache(kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))()
batch_inputs = JiugeAWQBatchedTask(tasks)
infer_batch_jiuge_awq(
self.model_instance,
*(batch_inputs.input_args()),
output,
)
return list(output)
def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
input_content = self.tokenizer.apply_chat_template(
conversation=[{"role": "user", "content": input_content}],
add_generation_prompt=True,
tokenize=False,
)
print(input_content, end="", flush=True)
tokens = self.tokenizer.encode(input_content)
infer_task = InferTask(
0,
tokens,
self.max_context_len(),
temperature_,
topk_,
topp_,
self.eos_token_id,
)
infer_task.bind_kvcache(KVCache(self))
steps = 0
total_time = 0
output_content = ""
for step_i in range(max_steps):
start_time = time.time()
output_tokens = self.batch_infer_one_round([infer_task])
end_time = time.time()
steps += 1
# output_str = (
# self.tokenizer._tokenizer.id_to_token(output_tokens[0])
# .replace("▁", " ")
# .replace("<0x0A>", "\n")
# )
output_str = self.tokenizer.decode(output_tokens[0])
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
break
infer_task.next(output_tokens[0])
if step_i > 0:
total_time += end_time - start_time
print("\n")
avg_time = total_time * 1000 / max(steps - 1, 1)
print(f"Time per step: {avg_time:.3f}ms")
infer_task._kv_cache.drop(self)
return output_content, avg_time
def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
tasks = [
InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
for i in range(batch_size)
]
kv_caches = [KVCache(self) for _ in range(batch_size)]
nll = 0.0
total_len = 0
for i in range(0, len(test_sequences), batch_size):
batch_id = 0
true_tokens = []
while batch_id < batch_size and batch_id + i < len(test_sequences):
input_tokens = test_sequences[i + batch_id][:-1]
true_tokens.extend(test_sequences[i + batch_id][1:])
tasks[batch_id].tokens = input_tokens
tasks[batch_id].bind_kvcache(kv_caches[batch_id])
batch_id += 1
batch_inputs = JiugeAWQBatchedTask(tasks[:batch_id])
logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
)
forward_batch_jiuge_awq(
self.model_instance,
batch_inputs.tokens,
batch_inputs.ntok,
batch_inputs.req_lens,
batch_inputs.nreq,
batch_inputs.req_pos,
batch_inputs.kv_caches,
logits.data_ptr(),
)
logits = logits.float()
token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,]
log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab)
token_logprobs = log_probs[
torch.arange(batch_inputs.ntok), token_ids
] # (ntok,)
start = 0
for l in batch_inputs.req_lens_list:
nll += -token_logprobs[start : start + l].sum().item()
start += l
total_len += token_logprobs.numel()
for task in tasks:
task.release_kvcache()
return math.exp(nll / total_len)
def destroy_model_instance(self):
destroy_jiuge_awq_model(self.model_instance)
print("Model destroyed")
def test():
if len(sys.argv) < 3:
print(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
model_path = sys.argv[2]
device_type = DeviceType.DEVICE_TYPE_CPU
if sys.argv[1] == "--cpu":
device_type = DeviceType.DEVICE_TYPE_CPU
elif sys.argv[1] == "--nvidia":
device_type = DeviceType.DEVICE_TYPE_NVIDIA
elif sys.argv[1] == "--cambricon":
device_type = DeviceType.DEVICE_TYPE_CAMBRICON
elif sys.argv[1] == "--ascend":
device_type = DeviceType.DEVICE_TYPE_ASCEND
elif sys.argv[1] == "--metax":
device_type = DeviceType.DEVICE_TYPE_METAX
elif sys.argv[1] == "--moore":
device_type = DeviceType.DEVICE_TYPE_MOORE
elif sys.argv[1] == "--iluvatar":
device_type = DeviceType.DEVICE_TYPE_ILUVATAR
else:
print(
"Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
)
sys.exit(1)
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeAWQForCausalLM(model_path, device_type, ndev)
model.generate("山东最高的山是?", 500)
model.destroy_model_instance()
if __name__ == "__main__":
test()
from jiuge import JiugeForCauslLM
from jiuge_awq import JiugeAWQForCausalLM
from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool
......@@ -25,6 +26,7 @@ DEVICE_TYPE_MAP = {
"moore": DeviceType.DEVICE_TYPE_MOORE,
}
def parse_args():
parser = argparse.ArgumentParser(description="Launch the LLM inference server.")
parser.add_argument(
......@@ -58,19 +60,26 @@ def parse_args():
default=None,
help="Max token sequence length that model will handle (follows model config if not provided)",
)
parser.add_argument(
"--awq",
action="store_true",
help="Whether to use AWQ quantized model (default: False)",
)
return parser.parse_args()
args = parse_args()
device_type = DEVICE_TYPE_MAP[args.dev]
model_path = args.model_path
ndev = args.ndev
max_tokens = args.max_tokens
USE_AWQ = args.awq
MAX_BATCH = args.max_batch
print(
f"Using MAX_BATCH={MAX_BATCH}. Try reducing this value if an out-of-memory error occurs."
)
def chunk_json(id_, content=None, role=None, finish_reason=None):
delta = {}
if content:
......@@ -109,7 +118,14 @@ class AsyncInferTask(InferTask):
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens)
if USE_AWQ:
app.state.model = JiugeAWQForCausalLM(
model_path, device_type, ndev, max_tokens=max_tokens
)
else:
app.state.model = JiugeForCauslLM(
model_path, device_type, ndev, max_tokens=max_tokens
)
app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
app.state.request_queue = janus.Queue()
worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
......@@ -277,6 +293,7 @@ async def chat_completions(request: Request):
response = await chat(id_, data, request)
return JSONResponse(content=response)
if __name__ == "__main__":
uvicorn.run(App, host="0.0.0.0", port=8000)
......
import ctypes
from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
from ctypes import c_char, c_char_p, c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
import os
......@@ -77,15 +77,188 @@ class JiugeModelCSruct(ctypes.Structure):
pass
class DeepSeekV3MetaCStruct(ctypes.Structure):
_fields_ = [
# dtypes
("dt_logits", DataType),
("dt_norm", DataType),
("dt_quant_weight", DataType),
("dt_quant_scale", DataType),
("dt_quant_zero", DataType),
("dt_gate_weight", DataType),
("dt_gate_bias", DataType),
# sizes
("n_sparse_layer", c_size_t),
("n_dense_layer", c_size_t),
("d", c_size_t),
("nh", c_size_t),
("nkvh", c_size_t),
("d_rope", c_size_t),
("d_nope", c_size_t),
("r_q", c_size_t),
("r_kv", c_size_t),
("d_qk", c_size_t),
("d_v", c_size_t),
# routing / experts / vocab / ctx
("routed_scale", c_float),
("nexperts", c_size_t),
("kexperts", c_size_t),
("di", c_size_t),
("di_moe", c_size_t),
("dctx", c_size_t),
("dvoc", c_size_t),
# misc
("epsilon", c_float),
("rope_theta", c_float),
("end_token", c_uint),
]
class DeepSeekV3WeightsCStruct(ctypes.Structure):
pass
# void (*load_global_fn)(DeepSeekV3Weights*, void *cpu_ptr)
load_global_fn = ctypes.CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)
# void (*load_layer_fn)(DeepSeekV3Weights*, void *cpu_ptr, size_t layer_id)
load_layer_fn = ctypes.CFUNCTYPE(
None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t
)
# void (*load_layer_linear_fn)(DeepSeekV3Weights*, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id)
load_layer_linear_fn = ctypes.CFUNCTYPE(
None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)
# void (*load_layer_mlp_fn)(DeepSeekV3Weights*, ... , size_t layer_id)
load_layer_mlp_fn = ctypes.CFUNCTYPE(
None,
POINTER(DeepSeekV3WeightsCStruct),
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_size_t,
)
# void (*load_layer_expert_mlp_fn)(DeepSeekV3Weights*, ..., size_t layer_id, size_t expert_id)
load_layer_expert_mlp_fn = ctypes.CFUNCTYPE(
None,
POINTER(DeepSeekV3WeightsCStruct),
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_void_p,
c_size_t,
c_size_t,
)
# -------------------------------------------------------------------
# Struct containing all weight loading functions
# -------------------------------------------------------------------
class DeepSeekV3WeightLoaderCStruct(ctypes.Structure):
_fields_ = [
# Global
("load_input_embd", load_global_fn),
("load_output_norm", load_global_fn),
("load_output_embd", load_global_fn),
# Attention
("load_attn_norm", load_layer_fn),
("load_attn_q_a_proj", load_layer_linear_fn),
("load_attn_q_a_layernorm", load_layer_fn),
("load_attn_q_b_proj", load_layer_linear_fn),
("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
("load_attn_kv_a_layernorm", load_layer_fn),
("load_attn_kv_b_proj", load_layer_linear_fn),
("load_attn_o_proj", load_layer_linear_fn),
# MLP
("load_mlp_norm", load_layer_fn),
# MLP dense part
("load_mlp_dense", load_layer_mlp_fn),
# MLP sparse gating
("load_mlp_gate_weight", load_layer_fn),
("load_mlp_gate_bias", load_layer_fn),
# Shared experts
("load_mlp_shared_experts", load_layer_mlp_fn),
# Per-expert functions
("load_mlp_experts", load_layer_expert_mlp_fn),
]
class DeepSeekV3ModelCStruct(ctypes.Structure):
pass
class KVCacheCStruct(ctypes.Structure):
pass
class DeepSeekV3CacheCStruct(ctypes.Structure):
pass
class JiugeAWQMetaCStruct(ctypes.Structure):
_fields_ = [
("dt_logits", DataType),
("dt_linear_w", DataType),
("dt_norm_w", DataType),
("nlayer", c_size_t),
("d", c_size_t),
("nh", c_size_t),
("nkvh", c_size_t),
("dh", c_size_t),
("di", c_size_t),
("dctx", c_size_t),
("dvoc", c_size_t),
("epsilon", c_float),
("theta", c_float),
("end_token", c_uint),
("nbit", c_size_t),
("quant_group_size", c_size_t),
("has_qkv_bias", c_char),
]
class ModelWeightsCStruct(ctypes.Structure):
pass
class JiugeAWQModelCStruct(ctypes.Structure):
pass # opaque struct
def __open_library__():
lib_path = os.path.join(
os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so"
)
lib = ctypes.CDLL(lib_path)
lib.createKVCache.argtypes = [
c_size_t, # nlayers
c_size_t, # max_len
c_size_t, # nkvh_
c_size_t, # dk
c_size_t, # dv
DataType, # dtype
DeviceType, # device
POINTER(c_int), # dev_ids
c_size_t, # ndev
]
lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
lib.createJiugeModel.argtypes = [
POINTER(JiugeMetaCStruct), # JiugeMeta const *
......@@ -95,11 +268,9 @@ def __open_library__():
POINTER(c_int), # int const *dev_ids
]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)]
lib.inferBatch.restype = None
lib.inferBatch.argtypes = [
lib.inferBatchJiuge.restype = None
lib.inferBatchJiuge.argtypes = [
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
......@@ -112,8 +283,8 @@ def __open_library__():
POINTER(c_float), # float topp
POINTER(c_uint), # unsigned int *output
]
lib.forwardBatch.restype = None
lib.forwardBatch.argtypes = [
lib.forwardBatchJiuge.restype = None
lib.forwardBatchJiuge.argtypes = [
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
......@@ -124,14 +295,164 @@ def __open_library__():
c_void_p, # void *logits
]
# createDeepSeekV3WeightLoader
lib.createDeepSeekV3WeightLoader.argtypes = []
lib.createDeepSeekV3WeightLoader.restype = POINTER(DeepSeekV3WeightLoaderCStruct)
lib.createDeepSeekV3Weights.argtypes = [
POINTER(DeepSeekV3MetaCStruct),
DeviceType,
c_int,
POINTER(c_int),
]
lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)
lib.createDeepSeekV3Model.argtypes = [
POINTER(DeepSeekV3MetaCStruct),
POINTER(DeepSeekV3WeightsCStruct),
]
lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)
# destroyDeepSeekV3Model
lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
lib.destroyDeepSeekV3Model.restype = None
# createDeepSeekV3Cache
lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)
# dropDeepSeekV3Cache
lib.dropDeepSeekV3Cache.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(DeepSeekV3CacheCStruct),
]
lib.dropDeepSeekV3Cache.restype = None
# inferBatchDeepSeekV3
lib.inferBatchDeepSeekV3.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(c_uint),
c_uint,
POINTER(c_uint),
c_uint,
POINTER(c_uint),
POINTER(POINTER(DeepSeekV3CacheCStruct)),
POINTER(c_float),
POINTER(c_uint),
POINTER(c_float),
POINTER(c_uint),
]
lib.inferBatchDeepSeekV3.restype = None
# forwardBatchDeepSeekV3
lib.forwardBatchDeepSeekV3.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(c_uint),
c_uint,
POINTER(c_uint),
c_uint,
POINTER(c_uint),
POINTER(POINTER(DeepSeekV3CacheCStruct)),
c_void_p,
]
lib.forwardBatchDeepSeekV3.restype = None
lib.createJiugeAWQWeights.restype = POINTER(ModelWeightsCStruct)
lib.createJiugeAWQWeights.argtypes = [
POINTER(JiugeAWQMetaCStruct), # const JiugeAWQMeta*
DeviceType, # infiniDevice_t
c_int, # int ndev
POINTER(c_int), # const int* dev_ids
]
# createJiugeAWQModel
lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct)
lib.createJiugeAWQModel.argtypes = [
POINTER(JiugeAWQMetaCStruct), # const JiugeAWQMeta*
POINTER(ModelWeightsCStruct), # const ModelWeights*
]
# destroyJiugeAWQModel
lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)]
lib.destroyJiugeAWQModel.restype = None
# inferBatchJiugeAWQ
lib.inferBatchJiugeAWQ.argtypes = [
POINTER(JiugeAWQModelCStruct), # JiugeAWQModel*
POINTER(c_uint), # const uint32_t* tokens
c_uint, # uint32_t ntok
POINTER(c_uint), # const uint32_t* req_lens
c_uint, # uint32_t nreq
POINTER(c_uint), # const uint32_t* req_pos
POINTER(POINTER(KVCacheCStruct)), # struct KVCache** kv_caches
POINTER(c_float), # const float* temperature
POINTER(c_uint), # const uint32_t* topk
POINTER(c_float), # const float* topp
POINTER(c_uint), # uint32_t* output
]
lib.inferBatchJiugeAWQ.restype = None
# forwardBatchJiugeAWQ
lib.forwardBatchJiugeAWQ.argtypes = [
POINTER(JiugeAWQModelCStruct), # JiugeAWQModel*
POINTER(c_uint), # const uint32_t* tokens
c_uint, # uint32_t ntok
POINTER(c_uint), # const uint32_t* req_lens
c_uint, # uint32_t nreq
POINTER(c_uint), # const uint32_t* req_pos
POINTER(POINTER(KVCacheCStruct)), # struct KVCache** kv_caches
c_void_p, # void* logits
]
lib.forwardBatchJiugeAWQ.restype = None
lib.loadModelWeight.argtypes = [
POINTER(ModelWeightsCStruct), # struct ModelWeights*
c_char_p, # const char* name
c_void_p, # void* data
]
lib.loadModelWeight.restype = None
# loadModelWeightDistributed
lib.loadModelWeightDistributed.argtypes = [
POINTER(ModelWeightsCStruct), # struct ModelWeights*
c_char_p, # const char* name
c_void_p, # void* data
POINTER(c_int), # int* ranks
c_int, # int nrank
]
lib.loadModelWeightDistributed.restype = None
return lib
LIB = __open_library__()
def load_model_weight(weights, name, data):
LIB.loadModelWeight(weights, name.encode("utf-8"), data)
def load_model_weight_distributed(weights, name, data, ranks, nrank):
LIB.loadModelWeightDistributed(weights, name.encode("utf-8"), data, ranks, nrank)
create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch
forward_batch = LIB.forwardBatch
infer_batch_jiuge = LIB.inferBatchJiuge
forward_batch_jiuge = LIB.forwardBatchJiuge
create_jiuge_awq_weights = LIB.createJiugeAWQWeights
create_jiuge_awq_model = LIB.createJiugeAWQModel
destroy_jiuge_awq_model = LIB.destroyJiugeAWQModel
infer_batch_jiuge_awq = LIB.inferBatchJiugeAWQ
forward_batch_jiuge_awq = LIB.forwardBatchJiugeAWQ
create_deepseek_v3_model = LIB.createDeepSeekV3Model
destroy_deepseek_v3_model = LIB.destroyDeepSeekV3Model
create_deepseek_v3_weight_loader = LIB.createDeepSeekV3WeightLoader
create_deepseek_v3_weights = LIB.createDeepSeekV3Weights
create_deepseek_v3_cache = LIB.createDeepSeekV3Cache
drop_deepseek_v3_cache = LIB.dropDeepSeekV3Cache
infer_batch_deepseek_v3 = LIB.inferBatchDeepSeekV3
#pragma once
#include "tensor.hpp"
#include <memory>
#include <vector>
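// Indexed as k[idev][layer] / v[idev][layer]: one tensor per device shard and
// per layer, shaped {max_len, nkvh / ndev, dk or dv} by createKVCache.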
struct KVCache {
std::vector<std::vector<std::shared_ptr<Tensor>>> k, v;
};
#include "jiuge_impl.hpp"
#include "../cache.hpp"
__C struct KVCache *createKVCache(
size_t nlayers,
size_t max_len,
size_t nkvh_,
size_t dk,
size_t dv,
infiniDtype_t dtype,
infiniDevice_t device,
int *dev_ids,
size_t ndev) {
KVCache *cache = new KVCache();
auto nkvh = nkvh_ / ndev;
auto shape_k = std::vector<size_t>{max_len, nkvh, dk};
auto shape_v = std::vector<size_t>{max_len, nkvh, dv};
for (unsigned int idev = 0; idev < ndev; idev++) {
RUN_INFINI(infinirtSetDevice(device, dev_ids[idev]));
auto kcache = std::vector<std::shared_ptr<Tensor>>();
auto vcache = std::vector<std::shared_ptr<Tensor>>();
for (unsigned int layer = 0; layer < nlayers; layer++) {
kcache.push_back(std::move(Tensor::buffer(dtype, shape_k)));
vcache.push_back(std::move(Tensor::buffer(dtype, shape_v)));
}
cache->k.push_back(kcache);
cache->v.push_back(vcache);
......@@ -22,35 +31,47 @@ __C struct KVCache *createKVCache(const JiugeModel *model) {
return cache;
}
__C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType();
auto dtype = kv_cache->k[0][0]->dtype();
auto shape_k = kv_cache->k[0][0]->shape();
auto shape_v = kv_cache->v[0][0]->shape();
auto size_k = seq_len * shape_k[1] * shape_k[2] * dsize(dtype);
auto size_v = seq_len * shape_v[1] * shape_v[2] * dsize(dtype);
KVCache *new_kv_cache = new KVCache();
for (unsigned int idev = 0; idev < ndev; idev++) {
RUN_INFINI(infinirtSetDevice(device, kv_cache->k[idev][0]->deviceId()));
auto kcache = std::vector<std::shared_ptr<Tensor>>();
auto vcache = std::vector<std::shared_ptr<Tensor>>();
for (unsigned int layer = 0; layer < nlayers; layer++) {
kcache.push_back(std::move(Tensor::buffer(dtype, shape_k)));
vcache.push_back(std::move(Tensor::buffer(dtype, shape_v)));
}
new_kv_cache->k.push_back(kcache);
new_kv_cache->v.push_back(vcache);
for (unsigned int layer = 0; layer < nlayers; layer++) {
RUN_INFINI(infinirtMemcpy(new_kv_cache->k[idev][layer]->data(),
kv_cache->k[idev][layer]->data(),
size_k,
INFINIRT_MEMCPY_D2D));
RUN_INFINI(infinirtMemcpy(new_kv_cache->v[idev][layer]->data(),
kv_cache->v[idev][layer]->data(),
size_v,
INFINIRT_MEMCPY_D2D));
}
}
return new_kv_cache;
}
__C void dropKVCache(JiugeModel const *model, KVCache *kv_cache) {
auto ndev = model->dev_resources.size();
__C void dropKVCache(KVCache *kv_cache) {
auto ndev = kv_cache->k.size();
auto nlayers = kv_cache->k[0].size();
auto device = kv_cache->k[0][0]->deviceType();
for (unsigned int idev = 0; idev < ndev; idev++) {
RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
for (unsigned int layer = 0; layer < model->meta.nlayer; layer++) {
RUN_INFINI(infinirtSetDevice(device, kv_cache->k[idev][0]->deviceId()));
for (unsigned int layer = 0; layer < nlayers; layer++) {
kv_cache->k[idev][layer].reset();
kv_cache->v[idev][layer].reset();
}
......
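// --- Illustrative sketch (not part of the commit): exercising the now
// --- model-independent KV-cache C API above. Every value below (layer count,
// --- context length, head counts/dims, dtype) is a hypothetical example.
#include "infinicore_infer/cache.h"
void kv_cache_demo(infiniDevice_t device) {
int dev_ids[2] = {0, 1};
// One K and one V tensor per layer per device; with nkvh_=8 and ndev=2,
// each device holds 4 KV heads of dims dk=128 (K) and dv=128 (V).
KVCache *cache = createKVCache(32, 4096, 8, 128, 128,
INFINI_DTYPE_F16, device, dev_ids, 2);
// Fork a request: allocate a fresh cache and copy the first 100 positions.
KVCache *forked = duplicateKVCache(cache, 100);
dropKVCache(forked);
dropKVCache(cache);
}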
@@ -145,10 +145,13 @@ private:
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
LRUDescriptorCache<infiniopRoPEv2Descriptor_t> rope_v2_cache;
LRUDescriptorCache<infiniopRearrangeDescriptor_t> rearrange_cache;
LRUDescriptorCache<infiniopCausalSoftmaxDescriptor_t> causal_softmax_cache;
LRUDescriptorCache<infiniopTopkrouterDescriptor_t> causal_topkrouter_cache;
LRUDescriptorCache<infiniopSwiGLUDescriptor_t> swiglu_cache;
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
LRUDescriptorCache<infiniopDequantizeDescriptor_t> dequantize_cache;
public:
CacheManager(size_t capacity = 100)
@@ -156,10 +159,13 @@ public:
rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rope_cache(capacity, infiniopDestroyRoPEDescriptor),
rope_v2_cache(capacity, infiniopDestroyRoPEv2Descriptor),
rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
causal_topkrouter_cache(capacity, infiniopDestroyTopkrouterDescriptor),
swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor),
dequantize_cache(capacity, infiniopDestroyDequantizeDescriptor) {}
// Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
@@ -197,6 +203,14 @@ public:
rope_cache.put(key, desc);
}
bool getRoPEv2Descriptor(size_t key, infiniopRoPEv2Descriptor_t &desc) {
return rope_v2_cache.get(key, desc);
}
void putRoPEv2Descriptor(size_t key, const infiniopRoPEv2Descriptor_t &desc) {
rope_v2_cache.put(key, desc);
}
// Rearrange operations
bool getRearrangeDescriptor(size_t key, infiniopRearrangeDescriptor_t &desc) {
return rearrange_cache.get(key, desc);
@@ -215,6 +229,15 @@ public:
causal_softmax_cache.put(key, desc);
}
// Topkrouter operations
bool getTopkrouterDescriptor(size_t key, infiniopTopkrouterDescriptor_t &desc) {
return causal_topkrouter_cache.get(key, desc);
}
void putTopkrouterDescriptor(size_t key, const infiniopTopkrouterDescriptor_t &desc) {
causal_topkrouter_cache.put(key, desc);
}
// SwiGLU operations
bool getSwiGLUDescriptor(size_t key, infiniopSwiGLUDescriptor_t &desc) {
return swiglu_cache.get(key, desc);
@@ -233,6 +256,15 @@ public:
random_sample_cache.put(key, desc);
}
// Dequantize operations
bool getDequantizeDescriptor(size_t key, infiniopDequantizeDescriptor_t &desc) {
return dequantize_cache.get(key, desc);
}
void putDequantizeDescriptor(size_t key, const infiniopDequantizeDescriptor_t &desc) {
dequantize_cache.put(key, desc);
}
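// Typical call pattern for these accessors: hash the operand tensor
// descriptors into a key, probe the LRU cache, and only build a fresh
// infiniop descriptor on a miss, e.g.
//   auto key = CacheManager::createDescriptorKey(out, w, scale, zero);
//   if (!cm.getDequantizeDescriptor(key, desc)) { /* create desc */ cm.putDequantizeDescriptor(key, desc); }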
template <typename... Tensors>
static size_t createDescriptorKey(Tensors... tensors) {
size_t seed = 0;
......
#include "weights_loader.hpp"
#include "infinicore_infer/weights_loader.h"
#include "../utils.hpp"
#include <infinirt.h>
namespace infinicore {
WeightsLoader::WeightsLoader(infiniDevice_t dev, const std::vector<int> &dev_ids) : _device(dev), _dev_ids(dev_ids) {
_streams.resize(_dev_ids.size());
_weights.resize(_dev_ids.size());
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
_weights[rank] = std::unordered_map<std::string, std::shared_ptr<Tensor>>();
RUN_INFINI(infinirtStreamCreate(&_streams[rank]));
}
}
void WeightsLoader::register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank) {
_weights[rank][name] = tensor;
}
void WeightsLoader::load_weight(const std::string &name, const void *host_data) {
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
_weights[rank][name]->load(host_data, _streams[rank]);
}
for (int rank = int(_dev_ids.size() - 1); rank >= 0; rank--) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
}
}
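// Loads one logically-sharded weight: `host_data` is assumed to hold the
// per-rank shards back to back, so the shard for ranks[i] starts at
// i * numel * dsize(dtype) bytes into the buffer (see the offset below).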
void WeightsLoader::load_distributed_weight(const std::string &name, const void *host_data, const std::vector<int> &ranks) {
for (size_t i = 0; i < ranks.size(); i++) {
int rank = ranks[i];
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
_weights[rank][name]->load((char *)host_data + i * _weights[rank][name]->numel() * dsize(_weights[rank][name]->dtype()), _streams[rank]);
}
for (int rank = int(_dev_ids.size() - 1); rank >= 0; rank--) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
}
}
void WeightsLoader::load_rank_weight(const std::string &name, const void *host_data, int rank) {
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
_weights[rank][name]->load(host_data);
}
void WeightsLoader::finalize() {
int dev_id;
RUN_INFINI(infinirtGetDevice(nullptr, &dev_id));
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
RUN_INFINI(infinirtStreamDestroy(_streams[rank]));
}
RUN_INFINI(infinirtSetDevice(_device, dev_id));
}
std::shared_ptr<Tensor> WeightsLoader::get(const std::string &name, int rank) {
return _weights[rank][name];
}
} // namespace infinicore
__C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
std::string name_str(name);
// std::cout << "Loading weight: " << name_str << std::endl;
auto weights = reinterpret_cast<infinicore::WeightsLoader *>(weights_);
weights->load_weight(name_str, data);
}
__C void
loadModelWeightDistributed(struct ModelWeights *weights_, const char *name, void *data, int *ranks, int nrank) {
std::string name_str(name);
// std::cout << "Loading dist weight: " << name_str << std::endl;
auto weights = reinterpret_cast<infinicore::WeightsLoader *>(weights_);
std::vector<int> rank_vec(ranks, ranks + nrank);
weights->load_distributed_weight(name_str, data, rank_vec);
}
#ifndef WEIGHTS_LOADER_HPP
#define WEIGHTS_LOADER_HPP
#include "../tensor.hpp"
#include <unordered_map>
#include <vector>
namespace infinicore {
class WeightsLoader {
protected:
std::vector<std::unordered_map<std::string, std::shared_ptr<Tensor>>> _weights;
infiniDevice_t _device;
std::vector<int> _dev_ids;
std::vector<infinirtStream_t> _streams;
public:
WeightsLoader(infiniDevice_t, const std::vector<int> &dev_ids);
void register_weight(const std::string &name, std::shared_ptr<Tensor> tensor, int rank = 0);
void load_weight(const std::string &name, const void *host_data);
void load_distributed_weight(const std::string &name, const void *host_data, const std::vector<int> &ranks);
void load_rank_weight(const std::string &name, const void *host_data, int rank);
void finalize();
std::shared_ptr<Tensor> get(const std::string &name, int rank = 0);
const std::vector<int> &dev_ids() const { return _dev_ids; }
infiniDevice_t device() const { return _device; }
};
} // namespace infinicore
#endif // WEIGHTS_LOADER_HPP
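// --- Illustrative sketch (not part of the commit): registering and loading
// --- one weight through WeightsLoader. The weight name, shape, and dtype are
// --- hypothetical example values.
#include "weights_loader.hpp"
void weights_loader_demo(infiniDevice_t dev, const float *host_norm_weight) {
infinicore::WeightsLoader loader(dev, {0});
// Register a device tensor under a name, then stream the host buffer into it.
loader.register_weight("model.norm.weight",
Tensor::weight(nullptr, INFINI_DTYPE_F32, {4096}));
loader.load_weight("model.norm.weight", host_norm_weight);
loader.finalize(); // synchronize and destroy the per-rank copy streams
}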
#include "deepseek_v3_impl.hpp"
#include "../../tensor.hpp"
#include "../../utils.hpp"
#include "../inference_context.hpp"
#include "infinicore_infer.h"
#include <algorithm>
#include <cmath>
#include <random>
#include <thread>
#include <vector>
void createDeviceResource(DeepSeekV3DeviceResource *rsrc, const DeepSeekV3Meta *meta,
std::shared_ptr<DeepSeekV3DeviceWeights> weights,
infiniDevice_t device, int idev,
int ndev, int dev_id,
infinicclComm_t comm) {
RUN_INFINI(infinirtSetDevice(device, dev_id));
RUN_INFINI(infinirtStreamSynchronize(weights->load_stream));
infiniopHandle_t handle;
infiniopCreateHandle(&handle);
infinirtStream_t stream;
infinirtStreamCreate(&stream);
auto memory_pool = std::make_shared<MemoryPool>();
*rsrc = DeepSeekV3DeviceResource{
device,
dev_id,
handle,
weights,
stream,
comm,
memory_pool,
};
RUN_INFINI(infinirtDeviceSynchronize());
}
void releaseDeviceResource(DeepSeekV3DeviceResource &res) {
infinirtDeviceSynchronize();
res.weights.reset();
infiniopDestroyHandle(res.handle);
res.handle = nullptr;
infinirtStreamDestroy(res.stream);
res.stream = nullptr;
infinicclCommDestroy(res.comm);
res.comm = nullptr;
}
void inferDeviceBatch(const DeepSeekV3Meta &meta, DeepSeekV3DeviceResource &rsrc,
uint32_t idev, uint32_t ndev,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct DeepSeekV3Cache **caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output, void *last_logits) {
auto dt_logits = meta.dt_logits;
// auto dt_norm = meta.dt_norm;
// auto dt_quant_weight = meta.dt_quant_weight;
// auto dt_quant_scale = meta.dt_quant_scale;
// auto dt_quant_zero = meta.dt_quant_zero;
// auto dt_gate_weight = meta.dt_gate_weight;
// auto dt_gate_bias = meta.dt_gate_bias;
auto n_dense_layer = meta.n_dense_layer;
auto n_sparse_layer = meta.n_sparse_layer;
auto nlayer = n_dense_layer + n_sparse_layer;
size_t nh = meta.nh / size_t(ndev);
auto d = meta.d;
auto d_rope = meta.d_rope;
auto d_nope = meta.d_nope;
auto r_q = meta.r_q;
auto r_kv = meta.r_kv;
auto d_qk = meta.d_qk;
auto d_v = meta.d_v;
// auto routed_scale = meta.routed_scale;
// auto nexperts = meta.nexperts;
// auto kexperts = meta.kexperts;
auto di = meta.di / size_t(ndev);
auto dvoc = meta.dvoc;
auto stream = rsrc.stream;
auto weights = rsrc.weights;
// Allocate buffers
auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool);
auto q_a_buf = Tensor::buffer(dt_logits, {ntok, r_q}, rsrc.memory_pool);
auto q_buf = Tensor::buffer(dt_logits, {ntok, nh * d_qk}, rsrc.memory_pool);
auto kv_a_buf = Tensor::buffer(dt_logits, {ntok, r_kv + d_rope}, rsrc.memory_pool);
auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * d_v}, rsrc.memory_pool);
auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool);
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool);
auto result_cpu = std::vector<int64_t>(nreq);
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
size_t req_start = 0;
for (uint32_t req = 0; req < nreq; req++) {
for (uint32_t i = 0; i < req_lens[req]; i++) {
batch_pos_ids[req_start + i] = req_pos[req] + i;
}
req_start += req_lens[req];
}
std::shared_ptr<Tensor> pos_ids_buf;
if (rsrc.device == INFINI_DEVICE_CPU) {
pos_ids_buf = Tensor::weight(batch_pos_ids.data(), INFINI_DTYPE_U32, {ntok});
} else {
pos_ids_buf = Tensor::buffer(INFINI_DTYPE_U32, {ntok}, rsrc.memory_pool);
RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok,
INFINIRT_MEMCPY_H2D, stream));
}
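// Gather input embeddings: copy row tokens[i] of the embedding table into
// row i of logits_in (one device-to-device copy of d elements per token).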
for (uint32_t i = 0; i < ntok; i++) {
RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d),
weights->w_in_embd->data(tokens[i] * d),
dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream));
}
// Attention
// attention inner
size_t max_qk_size = 0;
size_t max_seq_len = 0;
size_t max_total_len = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
max_qk_size = std::max(max_qk_size, size_t(seq_len * total_len));
max_seq_len = std::max(max_seq_len, size_t(seq_len));
max_total_len = std::max(max_total_len, size_t(total_len));
}
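// Size the attention scratch buffers for the worst-case request in the batch:
// scores need seq_len * total_len entries per head, and the decompressed
// K/V need up to max_total_len rows.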
auto full_k_buf = Tensor::buffer(dt_logits, {max_total_len, nh * d_qk}, rsrc.memory_pool);
auto kv_b_buf = Tensor::buffer(dt_logits, {max_total_len, nh * (d_nope + d_v)}, rsrc.memory_pool);
auto attn_score_buf = Tensor::buffer(dt_logits, {nh, max_qk_size}, rsrc.memory_pool);
auto attn_val_buf = Tensor::buffer(dt_logits, {nh, max_seq_len, d_v}, rsrc.memory_pool);
// Compute
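// Every projection below is AWQ-quantized: dequant_linear(out, in, w, s, z,
// alpha, beta, residual, bias) appears to dequantize the packed weight w with
// scales s and zero-points z on the fly before the matmul; the trailing
// pointers are an optional residual/accumulator and bias.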
for (uint32_t layer = 0; layer < nlayer; layer++) {
// 1. Attention
// rms norm
rmsnorm(logits_out, logits_in, weights->w_layers[layer].mla_norm, meta.epsilon);
// q_proj
dequant_linear(q_a_buf, logits_out,
weights->w_layers[layer].mla->q_a_proj->w,
weights->w_layers[layer].mla->q_a_proj->s,
weights->w_layers[layer].mla->q_a_proj->z,
1.0, 0.0, nullptr, nullptr);
rmsnorm(q_a_buf, q_a_buf, weights->w_layers[layer].mla->q_a_norm, meta.epsilon);
dequant_linear(q_buf, q_a_buf,
weights->w_layers[layer].mla->q_b_proj->w,
weights->w_layers[layer].mla->q_b_proj->s,
weights->w_layers[layer].mla->q_b_proj->z,
1.0, 0.0, nullptr, nullptr);
auto q_rot = q_buf->view({ntok, nh, d_qk})->slice(2, d_nope, d_rope);
rope_v2(q_rot, q_rot, pos_ids_buf, weights->sin_table, weights->cos_table);
// kv_proj
dequant_linear(kv_a_buf, logits_out,
weights->w_layers[layer].mla->kv_a_proj->w,
weights->w_layers[layer].mla->kv_a_proj->s,
weights->w_layers[layer].mla->kv_a_proj->z,
1.0, 0.0, nullptr, nullptr);
auto kv_pass = kv_a_buf->slice(1, 0, r_kv);
rmsnorm(kv_pass, kv_pass, weights->w_layers[layer].mla->kv_a_norm, meta.epsilon);
auto k_rot = kv_a_buf->slice(1, r_kv, d_rope)->view({ntok, 1, d_rope});
rope_v2(k_rot, k_rot, pos_ids_buf, weights->sin_table, weights->cos_table);
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto past_len = req_pos[req];
auto seq_len = req_lens[req];
auto total_len = past_len + seq_len;
auto o_req = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nh, d_v});
auto q_req = q_buf->slice({{0, token_offset, seq_len}});
auto kv_a_req = kv_a_buf->slice({{0, token_offset, seq_len}});
auto kv_pass_req = kv_a_req->slice(1, 0, r_kv);
auto k_rot_req = kv_a_req->slice(1, r_kv, d_rope);
// concat cache
rearrange(caches[req]->kv_pass[idev][layer]->slice(0, past_len, seq_len), kv_pass_req);
rearrange(caches[req]->k_rot[idev][layer]->slice(0, past_len, seq_len), k_rot_req);
// kv_b_proj
auto kv_b_req = kv_b_buf->slice(0, 0, total_len);
dequant_linear(kv_b_req, caches[req]->kv_pass[idev][layer]->slice(0, 0, total_len),
weights->w_layers[layer].mla->kv_b_proj->w,
weights->w_layers[layer].mla->kv_b_proj->s,
weights->w_layers[layer].mla->kv_b_proj->z,
1.0, 0.0, nullptr, nullptr);
auto full_v_req = kv_b_req->slice(1, nh * d_nope, nh * d_v);
// concat k
auto full_k_req = full_k_buf->slice(0, 0, total_len);
auto full_k_pass_req = full_k_req->slice(1, 0, nh * d_nope);
auto full_k_rot_req = full_k_req->slice(1, nh * d_nope, nh * d_rope);
rearrange(full_k_pass_req, kv_b_req->slice(1, 0, nh * d_nope));
rearrange(full_k_rot_req->view({total_len, nh, d_rope}), k_rot_req->view_as({total_len, nh, d_rope}, {ptrdiff_t(d_rope), 0, 1})); // expand k_rot
// self attention
auto attn_score_req = attn_score_buf->slice(1, 0, seq_len * total_len)->view({nh, seq_len, total_len});
linear(attn_score_req,
q_req->view({seq_len, nh, d_qk})->permute({1, 0, 2}),
full_k_req->view({total_len, nh, d_qk})->permute({1, 2, 0}),
1.f / float(sqrt(d_qk)), 0.f, nullptr, nullptr);
// softmax
causalSoftmax(attn_score_req, attn_score_req);
// attn val
auto attn_val_req = attn_val_buf->slice(1, 0, seq_len)->view({nh, seq_len, d_v});
linear(attn_val_req, attn_score_req, full_v_req->view({total_len, nh, d_v})->permute({1, 0, 2}), 1.f, 0.f, nullptr, nullptr);
// rearrange attn val
rearrange(o_req, attn_val_req->permute({1, 0, 2}));
token_offset += seq_len;
}
// o_proj
dequant_linear(logits_in, o_buf,
weights->w_layers[layer].mla->o_proj->w,
weights->w_layers[layer].mla->o_proj->s,
weights->w_layers[layer].mla->o_proj->z,
1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
// All_reduce if distributed
if (rsrc.comm != nullptr) {
RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
}
// 2. MLP
rmsnorm(logits_out, logits_in, weights->w_layers[layer].mlp_norm, meta.epsilon);
if (layer < n_dense_layer) {
auto gate_dense = Tensor::buffer(dt_logits, {ntok, di}, rsrc.memory_pool);
auto up_dense = Tensor::buffer(dt_logits, {ntok, di}, rsrc.memory_pool);
dequant_linear(gate_dense, logits_out,
weights->w_layers[layer].dense_mlp->gate->w,
weights->w_layers[layer].dense_mlp->gate->s,
weights->w_layers[layer].dense_mlp->gate->z, 1.0, 0.0, nullptr, nullptr);
dequant_linear(up_dense, logits_out,
weights->w_layers[layer].dense_mlp->up->w,
weights->w_layers[layer].dense_mlp->up->s,
weights->w_layers[layer].dense_mlp->up->z, 1.0, 0.0, nullptr, nullptr);
swiglu(gate_dense, up_dense, gate_dense);
dequant_linear(logits_in, gate_dense,
weights->w_layers[layer].dense_mlp->down->w,
weights->w_layers[layer].dense_mlp->down->s,
weights->w_layers[layer].dense_mlp->down->z,
1.0, 0.0, idev == 0 ? logits_in : nullptr, nullptr); // only rank 0 adds residual
} else {
// ------------------------------------------------------------------------ //
//        The remaining layers use the sparse MoE MLP                        //
// ------------------------------------------------------------------------ //
// Scratch buffers allocated up front and reused by every expert MLP
auto moe_gate_buf = Tensor::buffer(dt_logits, {ntok, meta.di_moe}, rsrc.memory_pool);
auto moe_up_buf = Tensor::buffer(dt_logits, {ntok, meta.di_moe}, rsrc.memory_pool);
// Output accumulators allocated up front
std::shared_ptr<Tensor> shared_states = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); // output of the shared expert
std::shared_ptr<Tensor> router_states_sum = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); // weighted sum of the routed experts' outputs
// Router buffers allocated up front
std::shared_ptr<Tensor> router_logits = Tensor::buffer(dt_logits, {ntok, meta.nexperts}, rsrc.memory_pool); // ntok x 256: one routing logit per expert
std::shared_ptr<Tensor> values_gpu = Tensor::buffer(infiniDtype_t::INFINI_DTYPE_F32, {ntok * 8}, rsrc.memory_pool); // topkrouter output: the routing weight of each selected expert
std::shared_ptr<Tensor> indices_gpu = Tensor::buffer(infiniDtype_t::INFINI_DTYPE_I32, {ntok * 8}, rsrc.memory_pool); // topkrouter output: the ids of the selected experts (8 picked out of 256)
std::vector<float> values_cpu(ntok * 8, 0.f); // host copy of the routing weights (8 per token)
std::vector<int> indices_cpu(ntok * 8, 0); // host copy of the selected expert indices
// config parameters
float routed_scaling_factor = meta.routed_scale; // "routed_scaling_factor" in config.json; fixed at 2.5
size_t topk = 8; // "num_experts_per_tok" in config.json; fixed at 8
// Name the input explicitly
std::shared_ptr<Tensor> hidden_states = logits_out; // logits_out is the input of the whole MoE block; alias it as hidden_states
// ------------------------------------------------------------------------ //
//          Computation                                                      //
// ------------------------------------------------------------------------ //
// (1) Shared expert: hidden_states goes through the shared expert
{
// input: hidden_states
// output: shared_states
dequant_linear(moe_gate_buf, hidden_states,
weights->w_layers[layer].share_expert->gate->w,
weights->w_layers[layer].share_expert->gate->s,
weights->w_layers[layer].share_expert->gate->z, 1.0, 0.0, nullptr, nullptr);
dequant_linear(moe_up_buf, hidden_states,
weights->w_layers[layer].share_expert->up->w,
weights->w_layers[layer].share_expert->up->s,
weights->w_layers[layer].share_expert->up->z, 1.0, 0.0, nullptr, nullptr);
swiglu(moe_gate_buf, moe_up_buf, moe_gate_buf);
dequant_linear(shared_states, moe_gate_buf,
weights->w_layers[layer].share_expert->down->w,
weights->w_layers[layer].share_expert->down->s,
weights->w_layers[layer].share_expert->down->z, 1.0, 0.0, nullptr, nullptr); // no residual here; summed with the routed experts in step (4)
}
// (2) Top-k routing: hidden_states goes through topkrouter
{
// input: hidden_states
// output: values_cpu, indices_cpu
auto gate_weight = weights->w_layers[layer].route->w;
gemm(router_logits, hidden_states, gate_weight, 1.0, 0.0); // the gate weight is kept unquantized
auto gate_correction_bias = weights->w_layers[layer].route->b;
topkrouter(values_gpu, indices_gpu, router_logits, gate_correction_bias, routed_scaling_factor, topk);
RUN_INFINI(infinirtMemcpy((void *)values_cpu.data(), values_gpu->data(), values_cpu.size() * sizeof(float), INFINIRT_MEMCPY_D2H));
RUN_INFINI(infinirtMemcpy((void *)indices_cpu.data(), indices_gpu->data(), indices_cpu.size() * sizeof(int), INFINIRT_MEMCPY_D2H));
}
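// For token t the router yields top-8 pairs (indices_cpu[t*topk+k],
// values_cpu[t*topk+k]); the MoE output accumulated in step (3) below is
//   router_states_sum[t] = sum_k values[t][k] * Expert_{indices[t][k]}(hidden_states[t])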
// (3) MoE: hidden_states goes through its 8 routed experts
// input: hidden_states, values_cpu, indices_cpu
// output: router_states_sum
for (size_t itok = 0; itok < ntok; ++itok) { // iterate over each token, then over the experts selected for that token
std::shared_ptr<Tensor> hidden_states_i = hidden_states->slice(0, itok, 1);
std::shared_ptr<Tensor> router_states_sum_i = router_states_sum->slice(0, itok, 1);
std::shared_ptr<Tensor> moe_gate_buf_i = moe_gate_buf->slice(0, itok, 1);
std::shared_ptr<Tensor> moe_up_buf_i = moe_up_buf->slice(0, itok, 1);
// First selected expert : C = alpha * AB
{
// input: hidden_states_i
// output: router_states_sum_i
int index = indices_cpu[itok * topk];
float alpha = values_cpu[itok * topk];
dequant_linear(moe_gate_buf_i, hidden_states_i,
weights->w_layers[layer].experts[index]->gate->w,
weights->w_layers[layer].experts[index]->gate->s,
weights->w_layers[layer].experts[index]->gate->z, 1.0, 0.0, nullptr, nullptr);
dequant_linear(moe_up_buf_i, hidden_states_i,
weights->w_layers[layer].experts[index]->up->w,
weights->w_layers[layer].experts[index]->up->s,
weights->w_layers[layer].experts[index]->up->z, 1.0, 0.0, nullptr, nullptr);
swiglu(moe_gate_buf_i, moe_up_buf_i, moe_gate_buf_i);
dequant_linear(router_states_sum_i, moe_gate_buf_i,
weights->w_layers[layer].experts[index]->down->w,
weights->w_layers[layer].experts[index]->down->s,
weights->w_layers[layer].experts[index]->down->z, alpha, 0.0, nullptr, nullptr); // C = alpha * AB, overwriting router_states_sum_i
}
// Remaining selected experts : C = alpha * AB + C_prev
for (size_t k = 1; k < topk; ++k) {
int index = indices_cpu[itok * topk + k];
float alpha = values_cpu[itok * topk + k];
dequant_linear(moe_gate_buf_i, hidden_states_i,
weights->w_layers[layer].experts[index]->gate->w,
weights->w_layers[layer].experts[index]->gate->s,
weights->w_layers[layer].experts[index]->gate->z, 1.0, 0.0, nullptr, nullptr);
dequant_linear(moe_up_buf_i, hidden_states_i,
weights->w_layers[layer].experts[index]->up->w,
weights->w_layers[layer].experts[index]->up->s,
weights->w_layers[layer].experts[index]->up->z, 1.0, 0.0, nullptr, nullptr);
swiglu(moe_gate_buf_i, moe_up_buf_i, moe_gate_buf_i);
dequant_linear(router_states_sum_i, moe_gate_buf_i,
weights->w_layers[layer].experts[index]->down->w,
weights->w_layers[layer].experts[index]->down->s,
weights->w_layers[layer].experts[index]->down->z, alpha, 0.0, router_states_sum_i, nullptr); // accumulate: C = alpha * AB + C_prev
}
}
// (4) Sum the two kinds of expert output
// input: shared-expert result shared_states, routed-expert result router_states_sum
// output: shared_states (accumulated in place)
add(shared_states, shared_states, router_states_sum);
// (5) Final residual connection
add(logits_in, shared_states, logits_in);
}
// All_reduce if distributed
if (rsrc.comm != nullptr) {
RUN_INFINI(infinicclAllReduce(
logits_in->data(), logits_in->data(), ntok * d, dt_logits,
INFINICCL_SUM, rsrc.comm, stream));
RUN_INFINI(infinirtStreamSynchronize(stream));
}
}
// Sample and Output
if (idev == 0) {
if (last_logits != nullptr) {
rmsnorm(logits_out, logits_in, weights->w_out_norm, meta.epsilon);
auto last_logits_buf = Tensor::buffer(dt_logits, {ntok, dvoc}, rsrc.memory_pool);
linear(last_logits_buf, logits_out, weights->w_out_embd, 1.0, 0.0, nullptr, nullptr);
RUN_INFINI(infinirtStreamSynchronize(stream));
RUN_INFINI(infinirtMemcpy(last_logits, last_logits_buf->data(), dsize(dt_logits) * ntok * dvoc, INFINIRT_MEMCPY_D2H));
}
if (output != nullptr) {
size_t token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
token_offset += seq_len;
rmsnorm(logits_out->slice(0, req, 1),
logits_in->slice(0, token_offset - 1, 1),
weights->w_out_norm,
meta.epsilon);
}
linear(prob_buf, logits_out->slice(0, 0, nreq), weights->w_out_embd, 1.0, 0.0, nullptr, nullptr);
std::random_device _rd;
std::mt19937 gen(_rd());
token_offset = 0;
for (uint32_t req = 0; req < nreq; req++) {
auto seq_len = req_lens[req];
float random_val = std::uniform_real_distribution<float>(0, 1)(gen);
randomSample(result_buf->slice(0, req, 1)->view_as({}, {}),
prob_buf->slice(0, req, 1)->view_as({dvoc}, {1}),
random_val, topp[req], topk[req], temperature[req]);
token_offset += seq_len;
}
RUN_INFINI(infinirtStreamSynchronize(stream));
RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
for (uint32_t req = 0; req < nreq; req++) {
output[req] = uint32_t(result_cpu[req]);
}
}
}
}
__C void
inferBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct DeepSeekV3Cache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output) {
model->req.tokens = tokens;
model->req.ntok = ntok;
model->req.req_lens = req_lens;
model->req.nreq = nreq;
model->req.req_pos = req_pos;
model->req.kv_caches = kv_caches;
model->req.output = output;
model->req.logits = nullptr;
model->req.temperature = temperature;
model->req.topk = topk;
model->req.topp = topp;
for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].proceed = true;
lock.unlock();
model->states[idev].cv_start.notify_one();
}
for (size_t i = model->dev_ids.size(); i > 0; i--) {
auto idev = i - 1;
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
lock.unlock();
}
}
__C void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *model,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct DeepSeekV3Cache **kv_caches,
void *logits) {
model->req.tokens = tokens;
model->req.ntok = ntok;
model->req.req_lens = req_lens;
model->req.nreq = nreq;
model->req.req_pos = req_pos;
model->req.kv_caches = kv_caches;
model->req.output = nullptr;
model->req.logits = logits;
model->req.temperature = nullptr;
model->req.topk = nullptr;
model->req.topp = nullptr;
for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].proceed = true;
lock.unlock();
model->states[idev].cv_start.notify_one();
}
for (size_t i = model->dev_ids.size(); i > 0; i--) {
auto idev = i - 1;
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
lock.unlock();
}
}
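// Per-device worker protocol: inferBatchDeepSeekV3/forwardBatchDeepSeekV3
// above fill `req`, set `proceed`, and signal cv_start; each device thread
// (launchDevice below) runs the batch, clears `proceed`, and signals cv_done.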
void launchDevice(const DeepSeekV3Meta &meta, std::shared_ptr<DeepSeekV3DeviceWeights> weights, DeepSeekV3DeviceResource *rsrc, InferState &state, InferRequest &req,
infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) {
// Create Device Resource
createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm);
CacheManager cache_manager(100);
InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream);
// Set the inference context for this thread
setInferenceContext(&ctx);
{
std::unique_lock<std::mutex> lock(state.mtx);
state.loaded = true;
lock.unlock();
state.cv_load.notify_one();
}
// Infer Loop
while (true) {
std::unique_lock<std::mutex> lock(state.mtx);
state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; });
// quit if exit_flag is set
if (state.exit_flag) {
break;
}
inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok,
req.req_lens, req.nreq, req.req_pos, req.kv_caches,
req.temperature, req.topk, req.topp, req.output, req.logits);
state.proceed = false;
lock.unlock();
state.cv_done.notify_one();
}
// Clean-Up
releaseDeviceResource(*rsrc);
setInferenceContext(nullptr); // Clear the context when done
}
DeepSeekV3Model::DeepSeekV3Model(const DeepSeekV3Meta *_meta, const DeepSeekV3Weights *weights) : meta(*_meta) {
auto device_weights = weights->device_weights;
int ndev = device_weights.size();
device = device_weights[0]->device;
dev_ids.resize(ndev);
for (int i = 0; i < ndev; i++) {
dev_ids[i] = device_weights[i]->dev_id;
}
dev_resources = std::vector<DeepSeekV3DeviceResource>(ndev);
states = std::vector<InferState>(ndev);
threads.resize(ndev);
RUN_INFINI(infinirtInit());
auto comms = std::vector<infinicclComm_t>(ndev, nullptr);
if (ndev > 1) {
RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data()));
}
for (int i = 0; i < ndev; i++) {
threads[i] = std::thread(launchDevice, std::cref(meta), device_weights[i], &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]);
}
for (int i = 0; i < ndev; i++) {
std::unique_lock<std::mutex> lock(states[i].mtx);
states[i].cv_load.wait(lock, [&] { return states[i].loaded; });
lock.unlock();
}
}
__C struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *_meta,
const DeepSeekV3Weights *weights) {
DeepSeekV3Model *model = new DeepSeekV3Model(_meta, weights);
return model;
}
__C void
destroyDeepSeekV3Model(struct DeepSeekV3Model *model) {
auto ndev = model->dev_resources.size();
for (size_t idev = 0; idev < ndev; idev++) {
std::unique_lock<std::mutex> lock(model->states[idev].mtx);
model->states[idev].exit_flag = true;
lock.unlock();
model->states[idev].cv_start.notify_one();
}
for (size_t idev = 0; idev < ndev; idev++) {
model->threads[idev].join();
}
delete model;
}
#include "deepseek_v3_impl.hpp"
__C struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
DeepSeekV3Cache *cache = new DeepSeekV3Cache();
auto ndev = model->dev_resources.size();
auto nlayer = model->meta.n_dense_layer + model->meta.n_sparse_layer;
auto max_len = model->meta.dctx;
auto d_rope = model->meta.d_rope;
auto r_kv = model->meta.r_kv;
auto kv_pass_shape = std::vector<size_t>{max_len, r_kv};
auto k_rot_shape = std::vector<size_t>{max_len, d_rope};
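// MLA caches only the compressed KV latent (r_kv values per token) plus the
// shared RoPE key (d_rope values per token), instead of full per-head K/V.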
for (size_t idev = 0; idev < ndev; idev++) {
RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
auto kv_pass_cache = std::vector<std::shared_ptr<Tensor>>();
auto k_rot_cache = std::vector<std::shared_ptr<Tensor>>();
for (size_t layer = 0; layer < nlayer; layer++) {
kv_pass_cache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, kv_pass_shape)));
k_rot_cache.push_back(std::move(Tensor::buffer(model->meta.dt_logits, k_rot_shape)));
}
cache->kv_pass.push_back(kv_pass_cache);
cache->k_rot.push_back(k_rot_cache);
}
return cache;
}
__C void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *model,
struct DeepSeekV3Cache *cache) {
auto ndev = model->dev_resources.size();
auto nlayer = model->meta.n_dense_layer + model->meta.n_sparse_layer;
for (size_t idev = 0; idev < ndev; idev++) {
RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
for (size_t layer = 0; layer < nlayer; layer++) {
cache->kv_pass[idev][layer].reset();
cache->k_rot[idev][layer].reset();
}
}
delete cache;
}
#ifndef DEEPSEEK_V3_IMPL_H
#define DEEPSEEK_V3_IMPL_H
#include "infinicore_infer.h"
#include "../../allocator.hpp"
#include "../../tensor.hpp"
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
struct QuantLinearWeight {
std::shared_ptr<Tensor> w;
std::shared_ptr<Tensor> s;
std::shared_ptr<Tensor> z;
};
struct MLAWeight {
std::shared_ptr<Tensor> kv_a_norm, q_a_norm;
std::shared_ptr<QuantLinearWeight> kv_a_proj, kv_b_proj, o_proj, q_a_proj, q_b_proj;
};
struct GateWeight {
std::shared_ptr<Tensor> w;
std::shared_ptr<Tensor> b;
};
struct MLPWeight {
std::shared_ptr<QuantLinearWeight> gate, up, down;
};
struct LayerWeight {
std::shared_ptr<Tensor> mla_norm;
std::shared_ptr<MLAWeight> mla;
std::shared_ptr<Tensor> mlp_norm;
std::shared_ptr<MLPWeight> dense_mlp;
std::shared_ptr<GateWeight> route;
std::shared_ptr<MLPWeight> share_expert;
std::vector<std::shared_ptr<MLPWeight>> experts;
};
struct DeepSeekV3DeviceWeights {
std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
cos_table;
std::vector<LayerWeight> w_layers;
infiniDevice_t device;
int dev_id;
infinirtStream_t load_stream;
};
struct DeepSeekV3Weights {
std::vector<std::shared_ptr<DeepSeekV3DeviceWeights>> device_weights;
DeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device,
int ndev,
const int *dev_ids);
};
struct DeepSeekV3DeviceResource {
// Device
infiniDevice_t device;
int device_id;
infiniopHandle_t handle;
// Weights
std::shared_ptr<DeepSeekV3DeviceWeights> weights;
// Streams
infinirtStream_t stream;
// Communicator
infinicclComm_t comm;
std::shared_ptr<MemoryPool> memory_pool;
};
struct InferState {
std::mutex mtx;
std::condition_variable cv_load, cv_start, cv_done;
bool loaded = false;
bool proceed = false;
bool exit_flag = false;
};
struct InferRequest {
const uint32_t *tokens;
uint32_t ntok;
const uint32_t *req_lens;
uint32_t nreq;
const uint32_t *req_pos;
struct DeepSeekV3Cache **kv_caches;
const float *temperature;
const uint32_t *topk;
const float *topp;
uint32_t *output;
void *logits;
};
struct DeepSeekV3Model {
DeepSeekV3Meta meta;
infiniDevice_t device;
std::vector<int> dev_ids;
std::vector<DeepSeekV3DeviceResource> dev_resources;
std::vector<InferState> states;
std::vector<std::thread> threads;
InferRequest req;
DeepSeekV3Model(const DeepSeekV3Meta *, const DeepSeekV3Weights *weights);
};
struct DeepSeekV3Cache {
std::vector<std::vector<std::shared_ptr<Tensor>>> kv_pass, k_rot;
};
#endif
#include "deepseek_v3_impl.hpp"
#include <cmath>
inline std::shared_ptr<Tensor> getInEmbd(
const DeepSeekV3Meta *meta) {
auto shape = std::vector<size_t>({meta->dvoc, meta->d});
return Tensor::weight(nullptr, meta->dt_logits, shape);
}
inline std::shared_ptr<Tensor> getOutNorm(
const DeepSeekV3Meta *meta) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight(nullptr, meta->dt_norm, shape);
}
inline std::shared_ptr<Tensor> getOutEmbd(
const DeepSeekV3Meta *meta) {
auto shape = std::vector<size_t>({meta->dvoc, meta->d});
return Tensor::weight(nullptr, meta->dt_logits, shape)
->permute({1, 0});
}
inline std::shared_ptr<Tensor> getMLANorm(
const DeepSeekV3Meta *meta) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight(nullptr, meta->dt_norm, shape);
}
inline std::shared_ptr<QuantLinearWeight> getQuantLinear(
const DeepSeekV3Meta *meta, size_t in_dim, size_t out_dim) {
auto qw = std::make_shared<QuantLinearWeight>();
auto shape_w = std::vector<size_t>({in_dim, out_dim / 8});
qw->w = Tensor::weight(nullptr, INFINI_DTYPE_I32, shape_w);
qw->s = Tensor::weight(nullptr, meta->dt_quant_scale, {in_dim / 64, out_dim});
qw->z = Tensor::weight(nullptr, INFINI_DTYPE_I32, {in_dim / 64, out_dim / 8});
return qw;
}
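// Layout note: eight 4-bit weights are packed into each int32 (hence the
// out_dim / 8 columns), with one scale and one packed zero-point per group of
// 64 input rows -- consistent with AWQ int4 quantization at group size 64.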
// ------------------- MLA Weights -------------------
inline std::shared_ptr<Tensor> getMLPNorm(
const DeepSeekV3Meta *meta) {
auto shape = std::vector<size_t>({meta->d});
return Tensor::weight(nullptr, meta->dt_norm, shape);
}
inline std::shared_ptr<MLAWeight> getMLA(const DeepSeekV3Meta *meta, int ndev) {
auto mla = std::make_shared<MLAWeight>();
mla->q_a_proj = getQuantLinear(meta, meta->d, meta->r_q);
mla->q_a_norm = Tensor::weight(nullptr, meta->dt_norm, {meta->r_q});
mla->q_b_proj = getQuantLinear(meta, meta->r_q, meta->nh / ndev * meta->d_qk);
mla->kv_a_proj = getQuantLinear(meta, meta->d, meta->r_kv + meta->d_rope);
mla->kv_a_norm = Tensor::weight(nullptr, meta->dt_norm, {meta->r_kv});
mla->kv_b_proj = getQuantLinear(meta, meta->r_kv, meta->nh / ndev * (meta->d_nope + meta->d_v));
mla->o_proj = getQuantLinear(meta, meta->nh / ndev * meta->d_v, meta->d);
return mla;
}
// ------------------- Dense MLP -------------------
inline std::shared_ptr<MLPWeight> getMLP(const DeepSeekV3Meta *meta, size_t d, size_t di) {
auto mlp = std::make_shared<MLPWeight>();
mlp->gate = getQuantLinear(meta, d, di);
mlp->up = getQuantLinear(meta, d, di);
mlp->down = getQuantLinear(meta, di, d);
return mlp;
}
inline std::shared_ptr<MLPWeight> getDenseMLP(const DeepSeekV3Meta *meta, int ndev) {
return getMLP(meta, meta->d, meta->di / ndev);
}
// ------------------- Sparse Route + Experts -------------------
inline std::shared_ptr<GateWeight> getRouteWeight(
const DeepSeekV3Meta *meta) {
auto gw = std::make_shared<GateWeight>();
gw->w = Tensor::weight(nullptr, meta->dt_gate_weight, {meta->nexperts, meta->d})->permute({1, 0});
gw->b = Tensor::weight(nullptr, meta->dt_gate_bias, {meta->nexperts});
return gw;
}
inline std::shared_ptr<MLPWeight> getShareExpert(const DeepSeekV3Meta *meta, int ndev) {
return getMLP(meta, meta->d, meta->di_moe / ndev);
}
inline std::vector<std::shared_ptr<MLPWeight>> getExperts(const DeepSeekV3Meta *meta, int ndev) {
std::vector<std::shared_ptr<MLPWeight>> experts(meta->nexperts);
for (size_t i = 0; i < meta->nexperts; i++) {
experts[i] = getMLP(meta, meta->d, meta->di_moe / ndev);
}
return experts;
}
inline std::shared_ptr<Tensor> getSinTable(const DeepSeekV3Meta *meta) {
auto half_dh = meta->d_rope / 2;
auto unit = dsize(meta->dt_logits);
void *table = std::malloc(meta->dctx * half_dh * unit);
for (size_t i = 0; i < meta->dctx; i++) {
for (size_t j = 0; j < half_dh; j++) {
float _sin = std::sin(
static_cast<float>(i) / std::pow(meta->rope_theta, static_cast<float>(j) / half_dh));
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_sin);
} else if (meta->dt_logits == INFINI_DTYPE_BF16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_sin);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _sin;
} else {
std::cout << "unsupported data type" << std::endl;
exit(1);
}
}
}
auto shape = std::vector<size_t>({meta->dctx, half_dh});
auto tensor = Tensor::weight(table, meta->dt_logits, shape);
std::free(table);
return tensor;
}
inline std::shared_ptr<Tensor> getCosTable(const DeepSeekV3Meta *meta) {
auto half_dh = meta->d_rope / 2;
auto unit = dsize(meta->dt_logits);
void *table = std::malloc(meta->dctx * half_dh * unit);
for (size_t i = 0; i < meta->dctx; i++) {
for (size_t j = 0; j < half_dh; j++) {
float _cos = std::cos(
static_cast<float>(i) / std::pow(meta->rope_theta, static_cast<float>(j) / half_dh));
if (meta->dt_logits == INFINI_DTYPE_F16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_f16(_cos);
} else if (meta->dt_logits == INFINI_DTYPE_BF16) {
((uint16_t *)table)[i * half_dh + j] = f32_to_bf16(_cos);
} else if (meta->dt_logits == INFINI_DTYPE_F32) {
((float *)table)[i * half_dh + j] = _cos;
} else {
std::cout << "unsupported data type" << std::endl;
exit(1);
}
}
}
auto shape = std::vector<size_t>({meta->dctx, half_dh});
auto tensor = Tensor::weight(table, meta->dt_logits, shape);
std::free(table);
return tensor;
}
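// Both tables hold the standard RoPE angles: entry (i, j) is
// sin(i / theta^(j / (d_rope/2))) resp. cos(...) for position i and frequency
// index j; rope_v2 consumes them during inference.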
DeepSeekV3Weights::DeepSeekV3Weights(
const DeepSeekV3Meta *meta, infiniDevice_t device, int ndev, const int *dev_ids) {
device_weights = std::vector<std::shared_ptr<DeepSeekV3DeviceWeights>>(ndev);
for (int dev = 0; dev < ndev; dev++) {
int dev_id = dev_ids[dev];
RUN_INFINI(infinirtSetDevice(device, dev_id));
device_weights[dev] = std::make_shared<DeepSeekV3DeviceWeights>();
device_weights[dev]->device = device;
device_weights[dev]->dev_id = dev_id;
RUN_INFINI(infinirtStreamCreate(&device_weights[dev]->load_stream));
device_weights[dev]->w_in_embd = getInEmbd(meta);
device_weights[dev]->w_out_norm = getOutNorm(meta);
device_weights[dev]->w_out_embd = getOutEmbd(meta);
device_weights[dev]->sin_table = getSinTable(meta);
device_weights[dev]->cos_table = getCosTable(meta);
device_weights[dev]->w_layers = std::vector<LayerWeight>(meta->n_dense_layer + meta->n_sparse_layer);
for (size_t layer = 0; layer < meta->n_dense_layer + meta->n_sparse_layer; layer++) {
device_weights[dev]->w_layers[layer].mla_norm = getMLANorm(meta);
device_weights[dev]->w_layers[layer].mla = getMLA(meta, ndev);
device_weights[dev]->w_layers[layer].mlp_norm = getMLPNorm(meta);
if (layer < meta->n_dense_layer) {
device_weights[dev]->w_layers[layer].dense_mlp = getDenseMLP(meta, ndev);
} else {
device_weights[dev]->w_layers[layer].route = getRouteWeight(meta);
device_weights[dev]->w_layers[layer].share_expert = getShareExpert(meta, ndev);
device_weights[dev]->w_layers[layer].experts = getExperts(meta, ndev);
}
}
}
}
// --- Global
void load_input_embd(DeepSeekV3Weights *weights, void *cpu_ptr) {
std::cout << "Loading input embedding from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_in_embd->load(cpu_ptr, weight->load_stream);
}
}
void load_output_norm(DeepSeekV3Weights *weights, void *cpu_ptr) {
std::cout << "Loading output norm from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_out_norm->load(cpu_ptr, weight->load_stream);
}
}
void load_output_embd(DeepSeekV3Weights *weights, void *cpu_ptr) {
std::cout << "Loading output embedding from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_out_embd->load(cpu_ptr, weight->load_stream);
}
}
// --- Attention
void load_attn_norm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading attention norm " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mla_norm->load(cpu_ptr, weight->load_stream);
}
}
void load_attn_q_a_proj(DeepSeekV3Weights *weights,
void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
std::cout << "Loading attention q_a_proj " << layer << " from " << weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mla->q_a_proj->w->load(weight_ptr, weight->load_stream);
weight->w_layers[layer].mla->q_a_proj->s->load(scale_ptr, weight->load_stream);
weight->w_layers[layer].mla->q_a_proj->z->load(zero_ptr, weight->load_stream);
}
}
void load_attn_q_a_layernorm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading attention q_a_layernorm " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mla->q_a_norm->load(cpu_ptr, weight->load_stream);
}
}
inline void load_dist_linear(void *w_ptr, void *s_ptr, void *z_ptr, std::shared_ptr<Tensor> w, std::shared_ptr<Tensor> s, std::shared_ptr<Tensor> z, size_t ndev, size_t dev, infinirtStream_t stream) {
auto w_offset = w->shape()[0] * w->shape()[1] / ndev * dev * dsize(w->dtype());
auto s_offset = s->shape()[0] * s->shape()[1] / ndev * dev * dsize(s->dtype());
auto z_offset = z->shape()[0] * z->shape()[1] / ndev * dev * dsize(z->dtype());
w->load(reinterpret_cast<char *>(w_ptr) + w_offset, stream);
s->load(reinterpret_cast<char *>(s_ptr) + s_offset, stream);
z->load(reinterpret_cast<char *>(z_ptr) + z_offset, stream);
}
void load_attn_q_b_proj(DeepSeekV3Weights *weights,
void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
std::cout << "Loading attention q_b_proj " << layer << " from " << weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto w = weight->w_layers[layer].mla->q_b_proj->w;
auto s = weight->w_layers[layer].mla->q_b_proj->s;
auto z = weight->w_layers[layer].mla->q_b_proj->z;
load_dist_linear(weight_ptr, scale_ptr, zero_ptr, w, s, z, weights->device_weights.size(), dev, weight->load_stream);
}
}
void load_attn_kv_a_proj_with_mqa(DeepSeekV3Weights *weights,
void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
std::cout << "Loading attention kv_a_proj_with_mqa " << layer << " from " << weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mla->kv_a_proj->w->load(weight_ptr, weight->load_stream);
weight->w_layers[layer].mla->kv_a_proj->s->load(scale_ptr, weight->load_stream);
weight->w_layers[layer].mla->kv_a_proj->z->load(zero_ptr, weight->load_stream);
}
}
void load_attn_kv_a_layernorm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading attention kv_a_layernorm " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mla->kv_a_norm->load(cpu_ptr, weight->load_stream);
}
}
void load_attn_kv_b_proj(DeepSeekV3Weights *weights,
void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
std::cout << "Loading attention kv_b_proj " << layer << " from " << weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto w = weight->w_layers[layer].mla->kv_b_proj->w;
auto s = weight->w_layers[layer].mla->kv_b_proj->s;
auto z = weight->w_layers[layer].mla->kv_b_proj->z;
load_dist_linear(weight_ptr, scale_ptr, zero_ptr, w, s, z, weights->device_weights.size(), dev, weight->load_stream);
}
}
void load_attn_o_proj(DeepSeekV3Weights *weights,
void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer) {
std::cout << "Loading attention o_proj " << layer << " from " << weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto w = weight->w_layers[layer].mla->o_proj->w;
auto s = weight->w_layers[layer].mla->o_proj->s;
auto z = weight->w_layers[layer].mla->o_proj->z;
load_dist_linear(weight_ptr, scale_ptr, zero_ptr, w, s, z, weights->device_weights.size(), dev, weight->load_stream);
}
}
// --- MLP
void load_mlp_norm(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading mlp norm " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].mlp_norm->load(cpu_ptr, weight->load_stream);
}
}
void load_mlp_dense(DeepSeekV3Weights *weights,
void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
size_t layer_id) {
std::cout << "Loading mlp dense " << layer_id << " from " << gate_weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto gate_w = weight->w_layers[layer_id].dense_mlp->gate->w;
auto gate_s = weight->w_layers[layer_id].dense_mlp->gate->s;
auto gate_z = weight->w_layers[layer_id].dense_mlp->gate->z;
auto up_w = weight->w_layers[layer_id].dense_mlp->up->w;
auto up_s = weight->w_layers[layer_id].dense_mlp->up->s;
auto up_z = weight->w_layers[layer_id].dense_mlp->up->z;
auto down_w = weight->w_layers[layer_id].dense_mlp->down->w;
auto down_s = weight->w_layers[layer_id].dense_mlp->down->s;
auto down_z = weight->w_layers[layer_id].dense_mlp->down->z;
load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr, gate_w, gate_s, gate_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr, up_w, up_s, up_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr, down_w, down_s, down_z, weights->device_weights.size(), dev, weight->load_stream);
}
}
void load_mlp_gate_weight(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading mlp gate weight " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].route->w->load(cpu_ptr, weight->load_stream);
}
}
void load_mlp_gate_bias(DeepSeekV3Weights *weights, void *cpu_ptr, size_t layer) {
std::cout << "Loading mlp gate bias " << layer << " from " << cpu_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
weight->w_layers[layer].route->b->load(cpu_ptr, weight->load_stream);
}
}
void load_mlp_shared_experts(DeepSeekV3Weights *weights,
void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
size_t layer_id) {
std::cout << "Loading mlp shared experts " << layer_id << " from " << gate_weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto gate_w = weight->w_layers[layer_id].share_expert->gate->w;
auto gate_s = weight->w_layers[layer_id].share_expert->gate->s;
auto gate_z = weight->w_layers[layer_id].share_expert->gate->z;
auto up_w = weight->w_layers[layer_id].share_expert->up->w;
auto up_s = weight->w_layers[layer_id].share_expert->up->s;
auto up_z = weight->w_layers[layer_id].share_expert->up->z;
auto down_w = weight->w_layers[layer_id].share_expert->down->w;
auto down_s = weight->w_layers[layer_id].share_expert->down->s;
auto down_z = weight->w_layers[layer_id].share_expert->down->z;
load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr, gate_w, gate_s, gate_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr, up_w, up_s, up_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr, down_w, down_s, down_z, weights->device_weights.size(), dev, weight->load_stream);
}
}
void load_mlp_experts(DeepSeekV3Weights *weights,
void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
size_t layer_id, size_t expert_id) {
std::cout << "Loading mlp expert " << layer_id << " expert " << expert_id
<< " from " << gate_weight_ptr << std::endl;
for (int dev = 0; dev < int(weights->device_weights.size()); dev++) {
auto weight = weights->device_weights[dev];
RUN_INFINI(infinirtSetDevice(weight->device, weight->dev_id));
auto gate_w = weight->w_layers[layer_id].experts[expert_id]->gate->w;
auto gate_s = weight->w_layers[layer_id].experts[expert_id]->gate->s;
auto gate_z = weight->w_layers[layer_id].experts[expert_id]->gate->z;
auto up_w = weight->w_layers[layer_id].experts[expert_id]->up->w;
auto up_s = weight->w_layers[layer_id].experts[expert_id]->up->s;
auto up_z = weight->w_layers[layer_id].experts[expert_id]->up->z;
auto down_w = weight->w_layers[layer_id].experts[expert_id]->down->w;
auto down_s = weight->w_layers[layer_id].experts[expert_id]->down->s;
auto down_z = weight->w_layers[layer_id].experts[expert_id]->down->z;
load_dist_linear(gate_weight_ptr, gate_scale_ptr, gate_zero_ptr, gate_w, gate_s, gate_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(up_weight_ptr, up_scale_ptr, up_zero_ptr, up_w, up_s, up_z, weights->device_weights.size(), dev, weight->load_stream);
load_dist_linear(down_weight_ptr, down_scale_ptr, down_zero_ptr, down_w, down_s, down_z, weights->device_weights.size(), dev, weight->load_stream);
}
}
static DeepSeekV3WeightLoader weight_loader = {
// Global
.load_input_embd = load_input_embd,
.load_output_norm = load_output_norm,
.load_output_embd = load_output_embd,
// Attention
.load_attn_norm = load_attn_norm,
.load_attn_q_a_proj = load_attn_q_a_proj,
.load_attn_q_a_layernorm = load_attn_q_a_layernorm,
.load_attn_q_b_proj = load_attn_q_b_proj,
.load_attn_kv_a_proj_with_mqa = load_attn_kv_a_proj_with_mqa,
.load_attn_kv_a_layernorm = load_attn_kv_a_layernorm,
.load_attn_kv_b_proj = load_attn_kv_b_proj,
.load_attn_o_proj = load_attn_o_proj,
// MLP
.load_mlp_norm = load_mlp_norm,
.load_mlp_dense = load_mlp_dense,
.load_mlp_gate_weight = load_mlp_gate_weight,
.load_mlp_gate_bias = load_mlp_gate_bias,
.load_mlp_shared_experts = load_mlp_shared_experts,
.load_mlp_experts = load_mlp_experts,
};
__C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
infiniDevice_t device,
int ndev,
const int *dev_ids) {
auto weights = new DeepSeekV3Weights(meta, device, ndev, dev_ids);
return weights;
}
__C DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader() {
return &weight_loader;
}