Unverified Commit 22804eaa authored by blkmjsian's avatar blkmjsian Committed by GitHub
Browse files

[T2-3-1]blkmjsian

- deepseek
- jiuge 4B awq 
parent 5c6000ec
#ifndef INFINICORE_INFER_H
#define INFINICORE_INFER_H

// Umbrella header for the InfiniCore inference library: pulls in the KV-cache
// API, the generic weights loader, and all model-specific headers.
// Fix: the jiuge.h include was previously listed twice; harmless with its
// include guard, but redundant.
#include "infinicore_infer/cache.h"
#include "infinicore_infer/weights_loader.h"

#include "infinicore_infer/models/deepseek.h"
#include "infinicore_infer/models/jiuge.h"

#endif /* INFINICORE_INFER_H */
#ifndef CACHE_H
#define CACHE_H

#include <infinirt.h>

/// @brief Allocate a KV cache spread across `ndev` devices.
/// @param nlayers number of transformer layers to cache
/// @param max_len maximum sequence length the cache can hold
/// @param nkvh_ total number of key/value heads (presumably split across
///              devices by the implementation — confirm against the .cc file)
/// @param dk per-head key dimension
/// @param dv per-head value dimension
/// @param dtype element type of the cached tensors
/// @param device device kind the cache lives on
/// @param dev_ids device ids, length `ndev`
/// @param ndev number of devices
__C __export struct KVCache *createKVCache(
    size_t nlayers,
    size_t max_len,
    size_t nkvh_,
    size_t dk,
    size_t dv,
    infiniDtype_t dtype,
    infiniDevice_t device,
    int *dev_ids,
    size_t ndev);

/// @brief Clone an existing cache, copying its first `seq_len` positions.
__C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len);

/// @brief Release a cache and all tensors it owns.
__C __export void dropKVCache(KVCache *kv_cache);

#endif /* CACHE_H */
#ifndef DEEPSEEK_V3_WEIGHTS_H
#define DEEPSEEK_V3_WEIGHTS_H

#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>
#include <stddef.h>
#include <stdint.h>

/* Opaque device-side weight storage for a DeepSeek-V3 model. */
struct DeepSeekV3Weights;

// Function pointer signatures

/* Load a global (non-layer) tensor from host memory. */
typedef void (*load_global_fn)(DeepSeekV3Weights *, void *cpu_ptr);
/* Load one per-layer tensor from host memory. */
typedef void (*load_layer_fn)(DeepSeekV3Weights *, void *cpu_ptr, size_t layer_id);
/* Load one quantized linear: packed weight plus its scale and zero tensors. */
typedef void (*load_layer_linear_fn)(DeepSeekV3Weights *, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id);
/* Load a full gate/up/down MLP, each projection as (weight, scale, zero). */
typedef void (*load_layer_mlp_fn)(
    DeepSeekV3Weights *,
    void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
    void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
    void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
    size_t layer_id);
/* Same as load_layer_mlp_fn, targeted at one routed expert of a layer. */
typedef void (*load_layer_expert_mlp_fn)(
    DeepSeekV3Weights *,
    void *gate_weight_ptr, void *gate_scale_ptr, void *gate_zero_ptr,
    void *up_weight_ptr, void *up_scale_ptr, void *up_zero_ptr,
    void *down_weight_ptr, void *down_scale_ptr, void *down_zero_ptr,
    size_t layer_id, size_t expert_id);

// Struct containing all weight loading functions
typedef struct {
    // Global
    load_global_fn load_input_embd;
    load_global_fn load_output_norm;
    load_global_fn load_output_embd;
    // Attention
    load_layer_fn load_attn_norm;
    load_layer_linear_fn load_attn_q_a_proj;
    load_layer_fn load_attn_q_a_layernorm;
    load_layer_linear_fn load_attn_q_b_proj;
    load_layer_linear_fn load_attn_kv_a_proj_with_mqa;
    load_layer_fn load_attn_kv_a_layernorm;
    load_layer_linear_fn load_attn_kv_b_proj;
    load_layer_linear_fn load_attn_o_proj;
    // MLP
    load_layer_fn load_mlp_norm;
    // MLP dense part
    load_layer_mlp_fn load_mlp_dense;
    // MLP sparse gating
    load_layer_fn load_mlp_gate_weight;
    load_layer_fn load_mlp_gate_bias;
    // Shared experts
    load_layer_mlp_fn load_mlp_shared_experts;
    // Per-expert functions
    load_layer_expert_mlp_fn load_mlp_experts;
} DeepSeekV3WeightLoader;

struct DeepSeekV3Model;

/* Model hyperparameters handed from Python to the C side.
 * Layout must match DeepSeekV3MetaCStruct on the Python side exactly. */
typedef struct {
    // dtypes of the various tensor families
    infiniDtype_t dt_logits;
    infiniDtype_t dt_norm;
    infiniDtype_t dt_quant_weight;
    infiniDtype_t dt_quant_scale;
    infiniDtype_t dt_quant_zero;
    infiniDtype_t dt_gate_weight;
    infiniDtype_t dt_gate_bias;
    // sizes (names follow DeepSeek-V3 conventions; semantics of d_rope/d_nope,
    // r_q/r_kv etc. should be confirmed against the model implementation)
    size_t n_sparse_layer;
    size_t n_dense_layer;
    size_t d;
    size_t nh;
    size_t nkvh;
    size_t d_rope;
    size_t d_nope;
    size_t r_q;
    size_t r_kv;
    size_t d_qk;
    size_t d_v;
    // expert routing / feed-forward / vocab / context sizes
    float routed_scale;
    size_t nexperts;
    size_t kexperts;
    size_t di;
    size_t di_moe;
    size_t dctx;
    size_t dvoc;
    // misc
    float epsilon;
    float rope_theta;
    uint32_t end_token;
} DeepSeekV3Meta;

//////////////////// APIs ///////////////////////

/// @brief Create a model from its hyperparameters and preloaded weights.
/// NOTE(review): the previous doc listed device/ndev/dev_ids parameters that
/// do not appear in this signature — they belong to createDeepSeekV3Weights.
__C __export struct DeepSeekV3Model *
createDeepSeekV3Model(const DeepSeekV3Meta *,
                      const DeepSeekV3Weights *);

/// @brief Allocate device-side weight storage.
/// @param device coprocessor (accelerator) kind
/// @param ndev number of coprocessors
/// @param dev_ids coprocessor ids, length ndev
__C DeepSeekV3Weights *
createDeepSeekV3Weights(const DeepSeekV3Meta *meta,
                        infiniDevice_t device,
                        int ndev,
                        const int *dev_ids);

/// @brief Obtain the table of weight-loading callbacks.
__C __export DeepSeekV3WeightLoader *
createDeepSeekV3WeightLoader();

/// @brief Destroy the model.
__C __export void destroyDeepSeekV3Model(struct DeepSeekV3Model *);

/// @brief Create the per-request KV cache for this model.
__C __export struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *);

/// @brief Release a per-request KV cache.
__C __export void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *,
                    struct DeepSeekV3Cache *);

/// @brief Run one batched inference round and sample new tokens.
/// @param tokens input token buffer
/// @param ntok total number of input tokens
/// @param req_lens token count of each request
/// @param nreq number of requests
/// @param req_pos start position of each request
/// @param caches KV cache of each request
/// @param temperature sampling temperature (0. means greedy)
/// @param topk sampling top-k (1 means greedy)
/// @param topp sampling top-p
/// @param output output token array, one per request; length at least nreq
__C __export void
inferBatchDeepSeekV3(struct DeepSeekV3Model *,
                     const uint32_t *tokens, uint32_t ntok,
                     const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                     struct DeepSeekV3Cache **caches,
                     const float *temperature, const uint32_t *topk, const float *topp,
                     uint32_t *output);

/// @brief Run one batched forward pass and write the post-output-embedding logits.
/// @param tokens input token buffer
/// @param ntok total number of input tokens
/// @param req_lens token count of each request
/// @param nreq number of requests
/// @param req_pos start position of each request
/// @param caches KV cache of each request
/// @param logits output buffer for the logits
__C __export void
forwardBatchDeepSeekV3(struct DeepSeekV3Model *,
                       const uint32_t *tokens, uint32_t ntok,
                       const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                       struct DeepSeekV3Cache **caches,
                       void *logits);

#endif // DEEPSEEK_V3_WEIGHTS_H
\ No newline at end of file
......@@ -61,20 +61,6 @@ createJiugeModel(const JiugeMeta *,
__C __export void
destroyJiugeModel(struct JiugeModel *);
/// @brief 创建 KV Cache
__C __export struct KVCache *
createKVCache(const struct JiugeModel *);
/// @brief 复制 KV Cache
__C __export struct KVCache *
duplicateKVCache(const struct JiugeModel *,
const struct KVCache *, uint32_t seq_len);
/// @brief 销毁 KV Cache
__C __export void
dropKVCache(const struct JiugeModel *,
struct KVCache *);
/// @brief 批次推理一轮,并采样出新的 token
/// @param tokens 输入 token 地址
/// @param ntok 输入 token 数量
......@@ -87,12 +73,12 @@ dropKVCache(const struct JiugeModel *,
/// @param topp 采样 topp
/// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void
inferBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
inferBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
const float *temperature, const uint32_t *topk, const float *topp,
uint32_t *output);
/// @brief 批次推理一轮,输出 output embedding 后的 logits
/// @param tokens 输入 token 地址
......@@ -103,10 +89,10 @@ inferBatch(struct JiugeModel *,
/// @param kv_caches 每个请求的 KV Cache
/// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq
__C __export void
forwardBatch(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
forwardBatchJiuge(struct JiugeModel *,
const uint32_t *tokens, uint32_t ntok,
const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
struct KVCache **kv_caches,
void *logits);
#endif
#ifndef MODEL_JIUGE_AWQ_H
#define MODEL_JIUGE_AWQ_H

#include <infiniccl.h>
#include <infiniop.h>
#include <infinirt.h>
#include <stdint.h>

#include "../weights_loader.h"

struct JiugeAWQModel;

/* Hyperparameters of an AWQ-quantized Jiuge model.
 * Layout must match JiugeAWQMetaCStruct on the Python side exactly. */
typedef struct
{
    infiniDtype_t dt_logits;   // activation/logits dtype
    infiniDtype_t dt_linear_w; // packed quantized linear-weight dtype
    infiniDtype_t dt_norm_w;   // normalization weight dtype
    size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc;
    float epsilon, theta;      // rmsnorm epsilon, RoPE theta
    uint32_t end_token;        // EOS token id
    size_t nbit;               // quantization bit width
    size_t quant_group_size;   // quantization group size
    char has_qkv_bias;         // nonzero if Q/K/V projections carry a bias
} JiugeAWQMeta;

//////////////////// APIs ///////////////////////

/// @brief Allocate device-side weight storage for an AWQ Jiuge model.
__C __export struct ModelWeights *
createJiugeAWQWeights(const JiugeAWQMeta *,
                      infiniDevice_t device,
                      int ndev,
                      const int *dev_ids);

/// @brief Create the model from its meta and preloaded weights.
/// NOTE(review): the previous doc listed device/ndev/dev_ids, which are not
/// in this signature (they are fixed when creating the weights), yet the
/// Python caller passes five arguments — confirm the intended ABI.
__C __export struct JiugeAWQModel *
createJiugeAWQModel(const JiugeAWQMeta *,
                    const ModelWeights *);

/// @brief Destroy the model.
__C __export void
destroyJiugeAWQModel(struct JiugeAWQModel *);

/// @brief Run one batched inference round and sample new tokens.
/// @param tokens input token buffer
/// @param ntok total number of input tokens
/// @param req_lens token count of each request
/// @param nreq number of requests
/// @param req_pos start position of each request
/// @param kv_caches KV cache of each request
/// @param temperature sampling temperature (0. means greedy)
/// @param topk sampling top-k (1 means greedy)
/// @param topp sampling top-p
/// @param output output token array, one per request; length at least nreq
__C __export void
inferBatchJiugeAWQ(struct JiugeAWQModel *,
                   const uint32_t *tokens, uint32_t ntok,
                   const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                   struct KVCache **kv_caches,
                   const float *temperature, const uint32_t *topk, const float *topp,
                   uint32_t *output);

/// @brief Run one batched forward pass and write the post-output-embedding logits.
/// @param tokens input token buffer
/// @param ntok total number of input tokens
/// @param req_lens token count of each request
/// @param nreq number of requests
/// @param req_pos start position of each request
/// @param kv_caches KV cache of each request
/// @param logits output buffer for the logits
__C __export void
forwardBatchJiugeAWQ(struct JiugeAWQModel *,
                     const uint32_t *tokens, uint32_t ntok,
                     const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                     struct KVCache **kv_caches,
                     void *logits);

#endif
#ifndef WEIGHTS_LOADER_H
#define WEIGHTS_LOADER_H

#include <infinirt.h>

/* Opaque container for a model's weights. */
struct ModelWeights;

/// @brief Load the named weight tensor from host memory into `weights`.
/// @param name weight name (checkpoint tensor key)
/// @param data host pointer to the tensor contents
__C __export void
loadModelWeight(struct ModelWeights *weights, const char *name, void *data);

/// @brief Load the named weight onto a subset of device ranks (sharded load).
/// @param ranks device ranks to load onto, length nrank
__C __export void
loadModelWeightDistributed(struct ModelWeights *weights, const char *name, void *data, int *ranks, int nrank);

#endif // WEIGHTS_LOADER_H
This diff is collapsed.
......@@ -11,8 +11,8 @@ from libinfinicore_infer import (
destroy_jiuge_model,
create_kv_cache,
drop_kv_cache,
infer_batch,
forward_batch,
infer_batch_jiuge,
forward_batch_jiuge,
)
from infer_task import InferTask, KVCache
......@@ -506,13 +506,15 @@ class JiugeForCauslLM:
print(f"Creating model on {ndev} devices...")
load_start_time = time.time()
dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
self.ndev = ndev
self.device = device
self.model_instance = create_jiuge_model(
byref(self.meta),
byref(self.weights),
device,
ndev,
dev_ids,
self.dev_ids,
)
load_end_time = time.time()
print(f"Time used: {load_end_time - load_start_time:.3f}s")
......@@ -521,15 +523,25 @@ class JiugeForCauslLM:
return self.meta.dctx
def create_kv_cache(self):
return create_kv_cache(self.model_instance)
return create_kv_cache(
self.meta.nlayer,
self.meta.dctx,
self.meta.nkvh,
self.meta.dh,
self.meta.dh,
self.meta.dt_logits,
self.device,
self.dev_ids,
self.ndev,
)
def drop_kv_cache(self, kv_cache):
drop_kv_cache(self.model_instance, kv_cache)
drop_kv_cache(kv_cache)
def batch_infer_one_round(self, tasks: List[InferTask]):
output = (c_uint * len(tasks))()
batch_inputs = JiugeBatchedTask(tasks)
infer_batch(
infer_batch_jiuge(
self.model_instance,
*(batch_inputs.input_args()),
output,
......@@ -609,7 +621,7 @@ class JiugeForCauslLM:
logits = torch.zeros(
(batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
)
forward_batch(
forward_batch_jiuge(
self.model_instance,
batch_inputs.tokens,
batch_inputs.ntok,
......
from typing import List, Sequence
from libinfinicore_infer import (
JiugeAWQMetaCStruct,
KVCacheCStruct,
DataType,
DeviceType,
load_model_weight,
load_model_weight_distributed,
create_jiuge_awq_weights,
create_jiuge_awq_model,
destroy_jiuge_awq_model,
create_kv_cache,
drop_kv_cache,
infer_batch_jiuge_awq,
forward_batch_jiuge_awq,
)
from infer_task import InferTask, KVCache
from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers
torch.set_default_device("cpu")
class JiugeAWQMetaFromConfig(JiugeAWQMetaCStruct):
    """Builds the C-side JiugeAWQMeta struct from a HF-style ``config.json`` dict.

    Also keeps a few Python-only attributes (``scale_*``,
    ``torch_dtype_logits``) that are consumed by the weight-loading code,
    not by the C ABI.
    """

    def __init__(self, config, dtype=torch.float16, max_tokens=None):
        # Map the checkpoint's torch_dtype onto an InfiniCore dtype;
        # anything unrecognized falls back to fp16.
        if config["torch_dtype"] == "float16":
            dt_ = DataType.INFINI_DTYPE_F16
        elif config["torch_dtype"] == "float32":
            dt_ = DataType.INFINI_DTYPE_F32
        elif config["torch_dtype"] == "bfloat16":
            dt_ = DataType.INFINI_DTYPE_BF16
        else:
            dt_ = DataType.INFINI_DTYPE_F16
        # Embedding / residual scaling factors: identity unless the
        # fm9g/minicpm-style scaling keys are all present in the config.
        self.scale_input = 1.0
        self.scale_output = 1.0
        self.scale_o = 1.0
        self.scale_down = 1.0
        if (
            config["model_type"] in ["fm9g", "minicpm"]
            and "scale_emb" in config
            and "scale_depth" in config
            and "dim_model_base" in config
        ):
            self.scale_input = config["scale_emb"]
            self.scale_output = config["hidden_size"] // config["dim_model_base"]
            self.scale_o = config["scale_depth"] / math.sqrt(
                config["num_hidden_layers"]
            )
            self.scale_down = config["scale_depth"] / math.sqrt(
                config["num_hidden_layers"]
            )
        # Qwen2/3 checkpoints always carry QKV bias even when the
        # attention_bias key is absent.
        has_qkv_bias = (
            1 if "attention_bias" in config and config["attention_bias"] else 0
        )
        if config["model_type"] in ["qwen2", "qwen3"]:
            has_qkv_bias = 1
        # eos_token_id may be a scalar or a list; the C meta stores one id.
        eos_token_id = (
            config["eos_token_id"][0]
            if type(config["eos_token_id"]) == list
            else config["eos_token_id"]
        )
        super().__init__(
            dt_logits=dt_,
            # AWQ-packed linear weights are stored as int32 words.
            dt_linear_w=DataType.INFINI_DTYPE_I32,
            dt_norm_w=dt_,
            nlayer=config["num_hidden_layers"],
            d=config["hidden_size"],
            nh=config["num_attention_heads"],
            nkvh=(
                config["num_key_value_heads"]
                if "num_key_value_heads" in config
                else config["num_attention_heads"]
            ),
            dh=config["hidden_size"] // config["num_attention_heads"],
            di=config["intermediate_size"],
            dctx=(
                config["max_position_embeddings"] if max_tokens is None else max_tokens
            ),
            dvoc=config["vocab_size"],
            epsilon=config["rms_norm_eps"],
            # NOTE(review): fallback rope_theta of 100000.0 is unusual
            # (10000.0 is the common default) — confirm this is intentional.
            theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
            end_token=eos_token_id,
            nbit=config["quantization_config"]["bits"],
            quant_group_size=config["quantization_config"]["group_size"],
            has_qkv_bias=has_qkv_bias,
        )
        # Torch dtype used when allocating logits buffers on the Python side.
        self.torch_dtype_logits = dtype
class JiugeAWQBatchedTask:
    """Marshals a batch of InferTask objects into the ctypes arrays that the
    batched AWQ inference C API expects.

    The ``*_list`` attributes keep the plain-Python values (they are read by
    callers such as ``perplexity``); the unsuffixed attributes hold the
    ctypes arrays passed over the FFI boundary.
    """

    def __init__(self, tasks: List[InferTask]):
        self.tasks = tasks
        self.nreq = len(tasks)

        # Gather per-request scalars and flatten all tokens in one pass.
        self.req_lens_list = []
        self.req_pos_list = []
        self.kv_cache_ptrs = []
        self.temperaturas_list = []
        self.topks_list = []
        self.topps_list = []
        flat_tokens = []
        for task in tasks:
            self.req_lens_list.append(len(task.tokens))
            self.req_pos_list.append(task.pos)
            self.kv_cache_ptrs.append(task.kvcache().data())
            self.temperaturas_list.append(task.temperature)
            self.topks_list.append(task.topk)
            self.topps_list.append(task.topp)
            flat_tokens.extend(task.tokens)
        self.ntok = len(flat_tokens)

        # Materialize the ctypes arrays for the C call.
        self.tokens = (c_uint * self.ntok)(*flat_tokens)
        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
        self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
        self.topks = (c_uint * self.nreq)(*self.topks_list)
        self.topps = (c_float * self.nreq)(*self.topps_list)

    def input_args(self):
        """Positional arguments for inferBatchJiugeAWQ, in ABI order."""
        return (
            self.tokens,
            self.ntok,
            self.req_lens,
            self.nreq,
            self.req_pos,
            self.kv_caches,
            self.temperaturas,
            self.topks,
            self.topps,
        )
class JiugeAWQForCausalLM:
    """Jiuge causal LM backed by AWQ-quantized weights on InfiniCore devices.

    Owns the C-side weight and model handles plus the HF tokenizer, and
    exposes batched sampling (`batch_infer_one_round`), single-prompt
    generation (`generate`) and perplexity evaluation (`perplexity`).
    """

    def __init__(
        self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
    ):
        # Phase 1: read the config, allocate device-side weight storage,
        # and load the tokenizer.
        load_start_time = time.time()
        print(f"Creating model on {ndev} devices...")
        with open(os.path.join(model_dir_path, "config.json"), "r") as f:
            config = json.load(f)
        self.config = config
        eos_token_id = self.config["eos_token_id"]
        # Normalize to a list so membership tests work for scalar configs too.
        self.eos_token_id = (
            [eos_token_id] if type(eos_token_id) == int else eos_token_id
        )
        self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
        self.ndev = ndev
        self.device = device
        self.meta = JiugeAWQMetaFromConfig(config, max_tokens=max_tokens)
        self.weights = create_jiuge_awq_weights(
            self.meta,
            self.device,
            ndev,
            self.dev_ids,
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_dir_path, trust_remote_code=True
        )
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")

        # Phase 2: stream safetensors into the C weights, then build the model.
        load_start_time = time.time()
        print("Loading model weights to host...")
        self.load_all_safetensors_from_dir(os.path.join(model_dir_path))
        # NOTE(review): the C header declares createJiugeAWQModel(meta, weights)
        # with only two parameters, but five arguments are passed here —
        # confirm against the actual C ABI.
        self.model_instance = create_jiuge_awq_model(
            self.meta,
            self.weights,
            device,
            ndev,
            self.dev_ids,
        )
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")

    def load_all_safetensors_from_dir(self, dir_path_: str):
        """Loads every *.safetensors file in the directory into the C weights.

        o_proj/down_proj tensors are re-laid-out so each device's shard is
        contiguous, and the model's scale factors are folded into the
        embedding, lm_head, and o/down projection scale tensors.
        """
        dir_path_ = Path(dir_path_)
        for file in sorted(dir_path_.glob("*.safetensors")):
            with safetensors.safe_open(file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    # print(key)
                    tensor = f.get_tensor(key)
                    if "proj" in key and "bias" not in key:
                        if "o_proj" in key or "down_proj" in key:
                            # Split the input dimension into ndev shards and
                            # move the device axis to the front so each
                            # device's slice is contiguous.
                            tensor = (
                                tensor.reshape(tensor.shape[0], self.ndev, -1)
                                .permute(1, 0, 2)
                                .contiguous()
                            )
                            # Fold the residual-scaling factors into the
                            # quantization scales rather than the activations.
                            if "o_proj.scales" in key:
                                tensor = tensor * self.meta.scale_o
                            elif "down_proj.scales" in key:
                                tensor = tensor * self.meta.scale_down
                            load_model_weight_distributed(
                                self.weights,
                                key,
                                tensor.data_ptr(),
                                self.dev_ids,
                                self.ndev,
                            )
                        else:
                            load_model_weight_distributed(
                                self.weights,
                                key,
                                tensor.data_ptr(),
                                self.dev_ids,
                                self.ndev,
                            )
                    else:
                        # Non-projection tensors are replicated; apply the
                        # input/output embedding scales where relevant.
                        if "embed_tokens.weight" in key:
                            tensor = tensor * self.meta.scale_input
                        elif "lm_head.weight" in key:
                            tensor = tensor * self.meta.scale_output
                        load_model_weight(self.weights, key, tensor.data_ptr())

    def max_context_len(self):
        """Maximum context length supported by this model instance."""
        return self.meta.dctx

    def create_kv_cache(self):
        """Allocates a KV cache sized for this model's meta (dk == dv == dh)."""
        return create_kv_cache(
            self.meta.nlayer,
            self.meta.dctx,
            self.meta.nkvh,
            self.meta.dh,
            self.meta.dh,
            self.meta.dt_logits,
            self.device,
            self.dev_ids,
            self.ndev,
        )

    def drop_kv_cache(self, kv_cache):
        """Releases a KV cache created by `create_kv_cache`."""
        drop_kv_cache(kv_cache)

    def batch_infer_one_round(self, tasks: List[InferTask]):
        """Runs one decoding step for all tasks; returns one token per task."""
        output = (c_uint * len(tasks))()
        batch_inputs = JiugeAWQBatchedTask(tasks)
        infer_batch_jiuge_awq(
            self.model_instance,
            *(batch_inputs.input_args()),
            output,
        )
        return list(output)

    def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
        """Streams a completion for one prompt; returns (text, ms_per_step).

        Timing for the first step is excluded from the average (warm-up).
        """
        input_content = self.tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": input_content}],
            add_generation_prompt=True,
            tokenize=False,
        )
        print(input_content, end="", flush=True)
        tokens = self.tokenizer.encode(input_content)
        infer_task = InferTask(
            0,
            tokens,
            self.max_context_len(),
            temperature_,
            topk_,
            topp_,
            self.eos_token_id,
        )
        infer_task.bind_kvcache(KVCache(self))

        steps = 0
        total_time = 0
        output_content = ""

        for step_i in range(max_steps):
            start_time = time.time()
            output_tokens = self.batch_infer_one_round([infer_task])
            end_time = time.time()
            steps += 1
            # output_str = (
            #     self.tokenizer._tokenizer.id_to_token(output_tokens[0])
            #     .replace("▁", " ")
            #     .replace("<0x0A>", "\n")
            # )
            output_str = self.tokenizer.decode(output_tokens[0])
            output_content += output_str
            print(output_str, end="", flush=True)
            if output_tokens[0] in self.eos_token_id:
                break
            infer_task.next(output_tokens[0])
            if step_i > 0:
                total_time += end_time - start_time

        print("\n")
        # NOTE(review): divides by (steps - 1); raises ZeroDivisionError when
        # the very first sampled token is EOS (steps == 1) — confirm/guard.
        avg_time = total_time * 1000 / (steps - 1)
        print(f"Time per step: {avg_time:.3f}ms")

        infer_task._kv_cache.drop(self)
        return output_content, avg_time

    def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
        """Computes corpus perplexity over token sequences, in batches.

        For each sequence, tokens[:-1] are fed and tokens[1:] are scored.
        """
        tasks = [
            InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
            for i in range(batch_size)
        ]
        kv_caches = [KVCache(self) for _ in range(batch_size)]

        nll = 0.0
        total_len = 0

        for i in range(0, len(test_sequences), batch_size):
            # Fill up to batch_size tasks with the next sequences.
            batch_id = 0
            true_tokens = []
            while batch_id < batch_size and batch_id + i < len(test_sequences):
                input_tokens = test_sequences[i + batch_id][:-1]
                true_tokens.extend(test_sequences[i + batch_id][1:])
                tasks[batch_id].tokens = input_tokens
                tasks[batch_id].bind_kvcache(kv_caches[batch_id])
                batch_id += 1

            batch_inputs = JiugeAWQBatchedTask(tasks[:batch_id])
            logits = torch.zeros(
                (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
            )
            forward_batch_jiuge_awq(
                self.model_instance,
                batch_inputs.tokens,
                batch_inputs.ntok,
                batch_inputs.req_lens,
                batch_inputs.nreq,
                batch_inputs.req_pos,
                batch_inputs.kv_caches,
                logits.data_ptr(),
            )

            # Accumulate negative log-likelihood of the ground-truth tokens.
            logits = logits.float()
            token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
            token_logprobs = log_probs[
                torch.arange(batch_inputs.ntok), token_ids
            ]  # (ntok,)

            start = 0
            for l in batch_inputs.req_lens_list:
                nll += -token_logprobs[start : start + l].sum().item()
                start += l
            total_len += token_logprobs.numel()

        for task in tasks:
            task.release_kvcache()

        return math.exp(nll / total_len)

    def destroy_model_instance(self):
        """Frees the C-side model handle."""
        destroy_jiuge_awq_model(self.model_instance)
        print("Model destroyed")
def test():
    """CLI smoke test: load an AWQ Jiuge model and generate one completion.

    Usage: python jiuge_awq.py --<device> <path/to/model_dir> [n_device]
    Exits with status 1 on bad arguments.

    Fixes vs. the previous version: the usage string now lists the supported
    --iluvatar flag and the correct script name, and is no longer duplicated;
    flag dispatch uses a table instead of an if/elif chain.
    """
    usage = (
        "Usage: python jiuge_awq.py"
        " [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar]"
        " <path/to/model_dir> [n_device]"
    )
    # Map CLI flags to InfiniCore device types.
    device_map = {
        "--cpu": DeviceType.DEVICE_TYPE_CPU,
        "--nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
        "--cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
        "--ascend": DeviceType.DEVICE_TYPE_ASCEND,
        "--metax": DeviceType.DEVICE_TYPE_METAX,
        "--moore": DeviceType.DEVICE_TYPE_MOORE,
        "--iluvatar": DeviceType.DEVICE_TYPE_ILUVATAR,
    }
    if len(sys.argv) < 3 or sys.argv[1] not in device_map:
        print(usage)
        sys.exit(1)
    device_type = device_map[sys.argv[1]]
    model_path = sys.argv[2]
    ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
    model = JiugeAWQForCausalLM(model_path, device_type, ndev)
    model.generate("山东最高的山是?", 500)
    model.destroy_model_instance()
# Run the CLI smoke test when executed as a script.
if __name__ == "__main__":
    test()
from jiuge import JiugeForCauslLM
from jiuge_awq import JiugeAWQForCausalLM
from libinfinicore_infer import DeviceType
from infer_task import InferTask
from kvcache_pool import KVCachePool
......@@ -25,6 +26,7 @@ DEVICE_TYPE_MAP = {
"moore": DeviceType.DEVICE_TYPE_MOORE,
}
def parse_args():
parser = argparse.ArgumentParser(description="Launch the LLM inference server.")
parser.add_argument(
......@@ -58,19 +60,26 @@ def parse_args():
default=None,
help="Max token sequence length that model will handle (follows model config if not provided)",
)
parser.add_argument(
"--awq",
action="store_true",
help="Whether to use AWQ quantized model (default: False)",
)
return parser.parse_args()
args = parse_args()
device_type = DEVICE_TYPE_MAP[args.dev]
model_path = args.model_path
ndev = args.ndev
max_tokens = args.max_tokens
USE_AWQ = args.awq
MAX_BATCH = args.max_batch
print(
f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
)
def chunk_json(id_, content=None, role=None, finish_reason=None):
delta = {}
if content:
......@@ -109,7 +118,14 @@ class AsyncInferTask(InferTask):
@contextlib.asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
app.state.model = JiugeForCauslLM(model_path, device_type, ndev, max_tokens=max_tokens)
if USE_AWQ:
app.state.model = JiugeAWQForCausalLM(
model_path, device_type, ndev, max_tokens=max_tokens
)
else:
app.state.model = JiugeForCauslLM(
model_path, device_type, ndev, max_tokens=max_tokens
)
app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH)
app.state.request_queue = janus.Queue()
worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True)
......@@ -277,6 +293,7 @@ async def chat_completions(request: Request):
response = await chat(id_, data, request)
return JSONResponse(content=response)
if __name__ == "__main__":
uvicorn.run(App, host="0.0.0.0", port=8000)
......
import ctypes
from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
from ctypes import c_char, c_char_p, c_size_t, c_uint, c_int, c_float, c_void_p, POINTER
import os
......@@ -77,15 +77,188 @@ class JiugeModelCSruct(ctypes.Structure):
pass
class DeepSeekV3MetaCStruct(ctypes.Structure):
    """ctypes mirror of the C `DeepSeekV3Meta` struct.

    Field order and types must match the C header exactly.
    """

    _fields_ = [
        # dtypes
        ("dt_logits", DataType),
        ("dt_norm", DataType),
        ("dt_quant_weight", DataType),
        ("dt_quant_scale", DataType),
        ("dt_quant_zero", DataType),
        ("dt_gate_weight", DataType),
        ("dt_gate_bias", DataType),
        # sizes
        ("n_sparse_layer", c_size_t),
        ("n_dense_layer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("d_rope", c_size_t),
        ("d_nope", c_size_t),
        ("r_q", c_size_t),
        ("r_kv", c_size_t),
        ("d_qk", c_size_t),
        ("d_v", c_size_t),
        # routing / experts / vocab / ctx
        ("routed_scale", c_float),
        ("nexperts", c_size_t),
        ("kexperts", c_size_t),
        ("di", c_size_t),
        ("di_moe", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        # misc
        ("epsilon", c_float),
        ("rope_theta", c_float),
        ("end_token", c_uint),
    ]
class DeepSeekV3WeightsCStruct(ctypes.Structure):
    """Opaque handle to the C-side DeepSeekV3Weights storage."""

    pass


# ctypes callback types mirroring the function-pointer typedefs in the C header.

# void (*load_global_fn)(DeepSeekV3Weights*, void *cpu_ptr)
load_global_fn = ctypes.CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p)

# void (*load_layer_fn)(DeepSeekV3Weights*, void *cpu_ptr, size_t layer_id)
load_layer_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t
)

# void (*load_layer_linear_fn)(DeepSeekV3Weights*, void *weight_ptr, void *scale_ptr, void *zero_ptr, size_t layer_id)
load_layer_linear_fn = ctypes.CFUNCTYPE(
    None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t
)

# void (*load_layer_mlp_fn)(DeepSeekV3Weights*, ... , size_t layer_id)
# Nine void* arguments: gate/up/down x (weight, scale, zero).
load_layer_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
)

# void (*load_layer_expert_mlp_fn)(DeepSeekV3Weights*, ..., size_t layer_id, size_t expert_id)
load_layer_expert_mlp_fn = ctypes.CFUNCTYPE(
    None,
    POINTER(DeepSeekV3WeightsCStruct),
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_void_p,
    c_size_t,
    c_size_t,
)
# -------------------------------------------------------------------
# Struct containing all weight loading functions
# -------------------------------------------------------------------
class DeepSeekV3WeightLoaderCStruct(ctypes.Structure):
    """ctypes mirror of the C `DeepSeekV3WeightLoader` callback table.

    Field order must match the C struct exactly.
    """

    _fields_ = [
        # Global
        ("load_input_embd", load_global_fn),
        ("load_output_norm", load_global_fn),
        ("load_output_embd", load_global_fn),
        # Attention
        ("load_attn_norm", load_layer_fn),
        ("load_attn_q_a_proj", load_layer_linear_fn),
        ("load_attn_q_a_layernorm", load_layer_fn),
        ("load_attn_q_b_proj", load_layer_linear_fn),
        ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn),
        ("load_attn_kv_a_layernorm", load_layer_fn),
        ("load_attn_kv_b_proj", load_layer_linear_fn),
        ("load_attn_o_proj", load_layer_linear_fn),
        # MLP
        ("load_mlp_norm", load_layer_fn),
        # MLP dense part
        ("load_mlp_dense", load_layer_mlp_fn),
        # MLP sparse gating
        ("load_mlp_gate_weight", load_layer_fn),
        ("load_mlp_gate_bias", load_layer_fn),
        # Shared experts
        ("load_mlp_shared_experts", load_layer_mlp_fn),
        # Per-expert functions
        ("load_mlp_experts", load_layer_expert_mlp_fn),
    ]
class DeepSeekV3ModelCStruct(ctypes.Structure):
    """Opaque handle to a C `DeepSeekV3Model`."""

    pass


class KVCacheCStruct(ctypes.Structure):
    """Opaque handle to a C `KVCache`."""

    pass


class DeepSeekV3CacheCStruct(ctypes.Structure):
    """Opaque handle to a C `DeepSeekV3Cache`."""

    pass


class JiugeAWQMetaCStruct(ctypes.Structure):
    """ctypes mirror of the C `JiugeAWQMeta` struct (field order must match)."""

    _fields_ = [
        ("dt_logits", DataType),
        ("dt_linear_w", DataType),
        ("dt_norm_w", DataType),
        ("nlayer", c_size_t),
        ("d", c_size_t),
        ("nh", c_size_t),
        ("nkvh", c_size_t),
        ("dh", c_size_t),
        ("di", c_size_t),
        ("dctx", c_size_t),
        ("dvoc", c_size_t),
        ("epsilon", c_float),
        ("theta", c_float),
        ("end_token", c_uint),
        ("nbit", c_size_t),
        ("quant_group_size", c_size_t),
        ("has_qkv_bias", c_char),
    ]


class ModelWeightsCStruct(ctypes.Structure):
    """Opaque handle to a C `ModelWeights` container."""

    pass


class JiugeAWQModelCStruct(ctypes.Structure):
    pass  # opaque struct
def __open_library__():
lib_path = os.path.join(
os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so"
)
lib = ctypes.CDLL(lib_path)
lib.createKVCache.argtypes = [
c_size_t, # nlayers
c_size_t, # max_len
c_size_t, # nkvh_
c_size_t, # dk
c_size_t, # dv
DataType, # dtype
DeviceType, # device
POINTER(c_int), # dev_ids
c_size_t, # ndev
]
lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)]
lib.createJiugeModel.restype = POINTER(JiugeModelCSruct)
lib.createJiugeModel.argtypes = [
POINTER(JiugeMetaCStruct), # JiugeMeta const *
......@@ -95,11 +268,9 @@ def __open_library__():
POINTER(c_int), # int const *dev_ids
]
lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.argtypes = [POINTER(JiugeModelCSruct)]
lib.createKVCache.restype = POINTER(KVCacheCStruct)
lib.dropKVCache.argtypes = [POINTER(JiugeModelCSruct), POINTER(KVCacheCStruct)]
lib.inferBatch.restype = None
lib.inferBatch.argtypes = [
lib.inferBatchJiuge.restype = None
lib.inferBatchJiuge.argtypes = [
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
......@@ -112,8 +283,8 @@ def __open_library__():
POINTER(c_float), # float topp
POINTER(c_uint), # unsigned int *output
]
lib.forwardBatch.restype = None
lib.forwardBatch.argtypes = [
lib.forwardBatchJiuge.restype = None
lib.forwardBatchJiuge.argtypes = [
POINTER(JiugeModelCSruct), # struct JiugeModel const *
POINTER(c_uint), # unsigned int const *tokens
c_uint, # unsigned int ntok
......@@ -124,14 +295,164 @@ def __open_library__():
c_void_p, # void *logits
]
# createDeepSeekV3WeightLoader
lib.createDeepSeekV3WeightLoader.argtypes = []
lib.createDeepSeekV3WeightLoader.restype = POINTER(DeepSeekV3WeightLoaderCStruct)
lib.createDeepSeekV3Weights.argtypes = [
POINTER(DeepSeekV3MetaCStruct),
DeviceType,
c_int,
POINTER(c_int),
]
lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct)
lib.createDeepSeekV3Model.argtypes = [
POINTER(DeepSeekV3MetaCStruct),
POINTER(DeepSeekV3WeightsCStruct),
]
lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct)
# destroyDeepSeekV3Model
lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
lib.destroyDeepSeekV3Model.restype = None
# createDeepSeekV3Cache
lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)]
lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct)
# dropDeepSeekV3Cache
lib.dropDeepSeekV3Cache.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(DeepSeekV3CacheCStruct),
]
lib.dropDeepSeekV3Cache.restype = None
# inferBatchDeepSeekV3
lib.inferBatchDeepSeekV3.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(c_uint),
c_uint,
POINTER(c_uint),
c_uint,
POINTER(c_uint),
POINTER(POINTER(DeepSeekV3CacheCStruct)),
POINTER(c_float),
POINTER(c_uint),
POINTER(c_float),
POINTER(c_uint),
]
lib.inferBatchDeepSeekV3.restype = None
# forwardBatchDeepSeekV3
lib.forwardBatchDeepSeekV3.argtypes = [
POINTER(DeepSeekV3ModelCStruct),
POINTER(c_uint),
c_uint,
POINTER(c_uint),
c_uint,
POINTER(c_uint),
POINTER(POINTER(DeepSeekV3CacheCStruct)),
c_void_p,
]
lib.forwardBatchDeepSeekV3.restype = None
lib.createJiugeAWQWeights.restype = POINTER(ModelWeightsCStruct)
lib.createJiugeAWQWeights.argtypes = [
POINTER(JiugeAWQMetaCStruct), # const JiugeAWQMeta*
DeviceType, # infiniDevice_t
c_int, # int ndev
POINTER(c_int), # const int* dev_ids
]
# createJiugeAWQModel
lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct)
lib.createJiugeAWQModel.argtypes = [
POINTER(JiugeAWQMetaCStruct), # const JiugeAWQMeta*
POINTER(ModelWeightsCStruct), # const ModelWeights*
]
# destroyJiugeAWQModel
lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)]
lib.destroyJiugeAWQModel.restype = None
# inferBatchJiugeAWQ
lib.inferBatchJiugeAWQ.argtypes = [
POINTER(JiugeAWQModelCStruct), # JiugeAWQModel*
POINTER(c_uint), # const uint32_t* tokens
c_uint, # uint32_t ntok
POINTER(c_uint), # const uint32_t* req_lens
c_uint, # uint32_t nreq
POINTER(c_uint), # const uint32_t* req_pos
POINTER(POINTER(KVCacheCStruct)), # struct KVCache** kv_caches
POINTER(c_float), # const float* temperature
POINTER(c_uint), # const uint32_t* topk
POINTER(c_float), # const float* topp
POINTER(c_uint), # uint32_t* output
]
lib.inferBatchJiugeAWQ.restype = None
# forwardBatchJiugeAWQ
lib.forwardBatchJiugeAWQ.argtypes = [
POINTER(JiugeAWQModelCStruct), # JiugeAWQModel*
POINTER(c_uint), # const uint32_t* tokens
c_uint, # uint32_t ntok
POINTER(c_uint), # const uint32_t* req_lens
c_uint, # uint32_t nreq
POINTER(c_uint), # const uint32_t* req_pos
POINTER(POINTER(KVCacheCStruct)), # struct KVCache** kv_caches
c_void_p, # void* logits
]
lib.forwardBatchJiugeAWQ.restype = None
lib.loadModelWeight.argtypes = [
POINTER(ModelWeightsCStruct), # struct ModelWeights*
c_char_p, # const char* name
c_void_p, # void* data
]
lib.loadModelWeight.restype = None
# loadModelWeightDistributed
lib.loadModelWeightDistributed.argtypes = [
POINTER(ModelWeightsCStruct), # struct ModelWeights*
c_char_p, # const char* name
c_void_p, # void* data
POINTER(c_int), # int* ranks
c_int, # int nrank
]
lib.loadModelWeightDistributed.restype = None
return lib
LIB = __open_library__()
def load_model_weight(weights, name, data):
    """Copy one named host tensor into the model weights on every device rank.

    Args:
        weights: POINTER(ModelWeightsCStruct) returned by the weight-creation API.
        name: weight name as a Python str (encoded to UTF-8 for the C side).
        data: ctypes pointer/buffer holding the host-side tensor data.
    """
    encoded = name.encode("utf-8")
    LIB.loadModelWeight(weights, encoded, data)
def load_model_weight_distributed(weights, name, data, ranks, nrank):
    """Scatter a named host tensor across the given device ranks.

    Args:
        weights: POINTER(ModelWeightsCStruct) returned by the weight-creation API.
        name: weight name as a Python str (encoded to UTF-8 for the C side).
        data: ctypes pointer/buffer holding the concatenated per-rank shards.
        ranks: POINTER(c_int) array of destination rank indices.
        nrank: number of entries in ``ranks``.
    """
    encoded = name.encode("utf-8")
    LIB.loadModelWeightDistributed(weights, encoded, data, ranks, nrank)
# ---- Jiuge (unquantized) model API ----
create_jiuge_model = LIB.createJiugeModel
destroy_jiuge_model = LIB.destroyJiugeModel
# Shared KV-cache helpers (signatures declared in cache.h).
create_kv_cache = LIB.createKVCache
drop_kv_cache = LIB.dropKVCache
infer_batch = LIB.inferBatch
forward_batch = LIB.forwardBatch
infer_batch_jiuge = LIB.inferBatchJiuge
forward_batch_jiuge = LIB.forwardBatchJiuge
# ---- Jiuge AWQ (quantized) model API ----
create_jiuge_awq_weights = LIB.createJiugeAWQWeights
create_jiuge_awq_model = LIB.createJiugeAWQModel
destroy_jiuge_awq_model = LIB.destroyJiugeAWQModel
infer_batch_jiuge_awq = LIB.inferBatchJiugeAWQ
forward_batch_jiuge_awq = LIB.forwardBatchJiugeAWQ
# ---- DeepSeek-V3 model API ----
# NOTE(review): argtypes/restype for these DeepSeek symbols are not set in this
# section — presumably configured earlier in __open_library__; confirm.
create_deepseek_v3_model = LIB.createDeepSeekV3Model
destroy_deepseek_v3_model = LIB.destroyDeepSeekV3Model
create_deepseek_v3_weight_loader = LIB.createDeepSeekV3WeightLoader
create_deepseek_v3_weights = LIB.createDeepSeekV3Weights
create_deepseek_v3_cache = LIB.createDeepSeekV3Cache
drop_deepseek_v3_cache = LIB.dropDeepSeekV3Cache
infer_batch_deepseek_v3 = LIB.inferBatchDeepSeekV3
#pragma once
#include "tensor.hpp"
#include <memory>
#include <vector>
// Key/value attention cache shared by the model implementations.
// Tensors are indexed as k[idev][layer] / v[idev][layer]: the outer index is
// the device (tensor-parallel rank), the inner index is the transformer layer.
struct KVCache {
    std::vector<std::vector<std::shared_ptr<Tensor>>> k, v;
};
#include "jiuge_impl.hpp"
#include "../cache.hpp"
/// Allocate a KV cache with one K and one V buffer per (device, layer).
///
/// This section of the file is corrupted diff residue: the old
/// `createKVCache(const JiugeModel *)` body is interleaved with the new
/// explicit-parameter version, leaving duplicate signatures and stale
/// `model->` references that cannot compile. Reconstructed here to match the
/// declaration in infinicore_infer/cache.h.
///
/// @param nlayers number of transformer layers.
/// @param max_len maximum sequence length the cache can hold.
/// @param nkvh_   total number of KV heads (sharded evenly across devices).
/// @param dk      per-head key dimension.
/// @param dv      per-head value dimension.
/// @param dtype   element type of the cache tensors.
/// @param device  device type; dev_ids selects the physical devices.
/// @param dev_ids array of `ndev` device ids.
/// @param ndev    number of devices (tensor-parallel degree).
/// @return heap-allocated cache; release with dropKVCache().
__C struct KVCache *createKVCache(
    size_t nlayers,
    size_t max_len,
    size_t nkvh_,
    size_t dk,
    size_t dv,
    infiniDtype_t dtype,
    infiniDevice_t device,
    int *dev_ids,
    size_t ndev) {
    KVCache *cache = new KVCache();
    // Each device holds its shard of the KV heads.
    auto nkvh = nkvh_ / ndev;
    auto shape_k = std::vector<size_t>{max_len, nkvh, dk};
    auto shape_v = std::vector<size_t>{max_len, nkvh, dv};
    for (size_t idev = 0; idev < ndev; idev++) {
        // Tensor::buffer allocates on the currently selected device.
        RUN_INFINI(infinirtSetDevice(device, dev_ids[idev]));
        std::vector<std::shared_ptr<Tensor>> kcache, vcache;
        kcache.reserve(nlayers);
        vcache.reserve(nlayers);
        for (size_t layer = 0; layer < nlayers; layer++) {
            kcache.push_back(Tensor::buffer(dtype, shape_k));
            vcache.push_back(Tensor::buffer(dtype, shape_v));
        }
        cache->k.push_back(std::move(kcache));
        cache->v.push_back(std::move(vcache));
    }
    return cache;
}
/// Deep-copy a KV cache, copying only the first `seq_len` positions of each
/// layer's K/V buffers (device-to-device).
///
/// This span contained corrupted diff residue — the old model-based signature
/// and the new `(const KVCache *, size_t)` signature were both present, along
/// with two nested loops reusing the `layer` variable. Reconstructed to match
/// the declaration in infinicore_infer/cache.h: allocate fresh full-size
/// buffers, then memcpy the `seq_len` prefix per layer.
///
/// @param kv_cache source cache (must have at least one device and one layer).
/// @param seq_len  number of leading sequence positions to copy.
/// @return newly allocated cache with identical geometry to `kv_cache`.
__C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) {
    auto ndev = kv_cache->k.size();
    auto nlayers = kv_cache->k[0].size();
    // Geometry and placement are derived from the source cache itself.
    auto device = kv_cache->k[0][0]->deviceType();
    auto dtype = kv_cache->k[0][0]->dtype();
    auto shape_k = kv_cache->k[0][0]->shape();
    auto shape_v = kv_cache->v[0][0]->shape();
    // Caches are laid out [max_len, nkvh, d*], so the seq_len prefix is contiguous.
    auto size_k = seq_len * shape_k[1] * shape_k[2] * dsize(dtype);
    auto size_v = seq_len * shape_v[1] * shape_v[2] * dsize(dtype);
    KVCache *new_kv_cache = new KVCache();
    for (size_t idev = 0; idev < ndev; idev++) {
        RUN_INFINI(infinirtSetDevice(device, kv_cache->k[idev][0]->deviceId()));
        // Allocate full-size buffers for this device ...
        std::vector<std::shared_ptr<Tensor>> kcache, vcache;
        kcache.reserve(nlayers);
        vcache.reserve(nlayers);
        for (size_t layer = 0; layer < nlayers; layer++) {
            kcache.push_back(Tensor::buffer(dtype, shape_k));
            vcache.push_back(Tensor::buffer(dtype, shape_v));
        }
        new_kv_cache->k.push_back(std::move(kcache));
        new_kv_cache->v.push_back(std::move(vcache));
        // ... then copy the populated prefix of every layer.
        for (size_t layer = 0; layer < nlayers; layer++) {
            RUN_INFINI(infinirtMemcpy(new_kv_cache->k[idev][layer]->data(),
                                      kv_cache->k[idev][layer]->data(),
                                      size_k,
                                      INFINIRT_MEMCPY_D2D));
            RUN_INFINI(infinirtMemcpy(new_kv_cache->v[idev][layer]->data(),
                                      kv_cache->v[idev][layer]->data(),
                                      size_v,
                                      INFINIRT_MEMCPY_D2D));
        }
    }
    return new_kv_cache;
}
/// Release all tensors of a KV cache and delete the cache object.
///
/// This span was corrupted diff residue (old model-based signature mixed with
/// the new one, and the function tail truncated). Reconstructed to match the
/// `dropKVCache(KVCache *)` declaration in infinicore_infer/cache.h.
/// Tensors are reset with their owning device selected, since freeing device
/// memory may require the device to be current.
__C void dropKVCache(KVCache *kv_cache) {
    if (kv_cache == nullptr) {
        return; // tolerate double/NULL drop from the C API
    }
    auto ndev = kv_cache->k.size();
    for (size_t idev = 0; idev < ndev; idev++) {
        auto nlayers = kv_cache->k[idev].size();
        if (nlayers == 0) {
            continue;
        }
        auto device = kv_cache->k[idev][0]->deviceType();
        RUN_INFINI(infinirtSetDevice(device, kv_cache->k[idev][0]->deviceId()));
        for (size_t layer = 0; layer < nlayers; layer++) {
            kv_cache->k[idev][layer].reset();
            kv_cache->v[idev][layer].reset();
        }
    }
    delete kv_cache;
}
......@@ -145,10 +145,13 @@ private:
LRUDescriptorCache<infiniopRMSNormDescriptor_t> rms_norm_cache;
LRUDescriptorCache<infiniopGemmDescriptor_t> gemm_cache;
LRUDescriptorCache<infiniopRoPEDescriptor_t> rope_cache;
LRUDescriptorCache<infiniopRoPEv2Descriptor_t> rope_v2_cache;
LRUDescriptorCache<infiniopRearrangeDescriptor_t> rearrange_cache;
LRUDescriptorCache<infiniopCausalSoftmaxDescriptor_t> causal_softmax_cache;
LRUDescriptorCache<infiniopTopkrouterDescriptor_t> causal_topkrouter_cache;
LRUDescriptorCache<infiniopSwiGLUDescriptor_t> swiglu_cache;
LRUDescriptorCache<infiniopRandomSampleDescriptor_t> random_sample_cache;
LRUDescriptorCache<infiniopDequantizeDescriptor_t> dequantize_cache;
public:
CacheManager(size_t capacity = 100)
......@@ -156,10 +159,13 @@ public:
rms_norm_cache(capacity, infiniopDestroyRMSNormDescriptor),
gemm_cache(capacity, infiniopDestroyGemmDescriptor),
rope_cache(capacity, infiniopDestroyRoPEDescriptor),
rope_v2_cache(capacity, infiniopDestroyRoPEv2Descriptor),
rearrange_cache(capacity, infiniopDestroyRearrangeDescriptor),
causal_softmax_cache(capacity, infiniopDestroyCausalSoftmaxDescriptor),
causal_topkrouter_cache(capacity, infiniopDestroyTopkrouterDescriptor),
swiglu_cache(capacity, infiniopDestroySwiGLUDescriptor),
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor) {}
random_sample_cache(capacity, infiniopDestroyRandomSampleDescriptor),
dequantize_cache(capacity, infiniopDestroyDequantizeDescriptor) {}
// Add operations
bool getAddDescriptor(size_t key, infiniopAddDescriptor_t &desc) {
......@@ -197,6 +203,14 @@ public:
rope_cache.put(key, desc);
}
// Fetch a cached RoPE-v2 descriptor for `key` into `desc`;
// result mirrors the underlying LRU cache's get().
bool getRoPEv2Descriptor(size_t key, infiniopRoPEv2Descriptor_t &desc) {
    return rope_v2_cache.get(key, desc);
}
// Store a RoPE-v2 descriptor under `key` (evicted descriptors are destroyed
// by the destructor registered at construction).
void putRoPEv2Descriptor(size_t key, const infiniopRoPEv2Descriptor_t &desc) {
    rope_v2_cache.put(key, desc);
}
// Rearrange operations
bool getRearrangeDescriptor(size_t key, infiniopRearrangeDescriptor_t &desc) {
return rearrange_cache.get(key, desc);
......@@ -215,6 +229,15 @@ public:
causal_softmax_cache.put(key, desc);
}
// Topkrouter operations
// Fetch a cached top-k router descriptor (MoE routing) for `key` into `desc`.
bool getTopkrouterDescriptor(size_t key, infiniopTopkrouterDescriptor_t &desc) {
    return causal_topkrouter_cache.get(key, desc);
}
// Store a top-k router descriptor under `key`.
void putTopkrouterDescriptor(size_t key, const infiniopTopkrouterDescriptor_t &desc) {
    causal_topkrouter_cache.put(key, desc);
}
// SwiGLU operations
bool getSwiGLUDescriptor(size_t key, infiniopSwiGLUDescriptor_t &desc) {
return swiglu_cache.get(key, desc);
......@@ -233,6 +256,15 @@ public:
random_sample_cache.put(key, desc);
}
// Dequantize operations
// Fetch a cached dequantize descriptor (AWQ weight path) for `key` into `desc`.
bool getDequantizeDescriptor(size_t key, infiniopDequantizeDescriptor_t &desc) {
    return dequantize_cache.get(key, desc);
}
// Store a dequantize descriptor under `key`.
void putDequantizeDescriptor(size_t key, const infiniopDequantizeDescriptor_t &desc) {
    dequantize_cache.put(key, desc);
}
template <typename... Tensors>
static size_t createDescriptorKey(Tensors... tensors) {
size_t seed = 0;
......
#include "weights_loader.hpp"
#include "infinicore_infer/weights_loader.h"
#include "../utils.hpp"
#include <infinirt.h>
namespace infinicore {
WeightsLoader::WeightsLoader(infiniDevice_t dev, const std::vector<int> &dev_ids) : _device(dev), _dev_ids(dev_ids) {
_streams.resize(_dev_ids.size());
_weights.resize(_dev_ids.size());
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
_weights[rank] = std::unordered_map<std::string, std::shared_ptr<Tensor>>();
RUN_INFINI(infinirtStreamCreate(&_streams[rank]));
}
}
void WeightsLoader::resigter(const std::string &name, std::shared_ptr<Tensor> tensor, int rank) {
_weights[rank][name] = tensor;
}
void WeightsLoader::load_weight(const std::string &name, const void *host_data) {
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
_weights[rank][name]->load(host_data, _streams[rank]);
}
for (int rank = int(_dev_ids.size() - 1); rank >= 0; rank--) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
}
}
void WeightsLoader::load_distributed_weight(const std::string &name, const void *host_data, const std::vector<int> &ranks) {
for (size_t i = 0; i < ranks.size(); i++) {
int rank = ranks[i];
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
_weights[rank][name]->load((char *)host_data + i * _weights[rank][name]->numel() * dsize(_weights[rank][name]->dtype()), _streams[rank]);
}
for (int rank = int(_dev_ids.size() - 1); rank >= 0; rank--) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
}
}
void WeightsLoader::load_rank_weight(const std::string &name, const void *host_data, int rank) {
auto it = _weights[rank].find(name);
if (it == _weights[rank].end()) {
std::cerr << "Weight " << name << " not found in rank " << rank << std::endl;
std::abort();
}
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
_weights[rank][name]->load(host_data);
}
void WeightsLoader::finalize() {
int dev_id;
RUN_INFINI(infinirtGetDevice(nullptr, &dev_id));
for (int rank = 0; rank < int(_dev_ids.size()); rank++) {
RUN_INFINI(infinirtSetDevice(_device, _dev_ids[rank]));
RUN_INFINI(infinirtStreamSynchronize(_streams[rank]));
RUN_INFINI(infinirtStreamDestroy(_streams[rank]));
}
RUN_INFINI(infinirtSetDevice(_device, dev_id));
}
std::shared_ptr<Tensor> WeightsLoader::get(const std::string &name, int rank) {
return _weights[rank][name];
}
} // namespace infinicore
// C ABI wrapper: forward a host weight buffer to the WeightsLoader behind the
// opaque ModelWeights handle, replicating it onto every rank.
__C void
loadModelWeight(struct ModelWeights *weights_, const char *name, void *data) {
    auto loader = reinterpret_cast<infinicore::WeightsLoader *>(weights_);
    loader->load_weight(std::string(name), data);
}
// C ABI wrapper: scatter a host weight buffer across the `nrank` ranks listed
// in `ranks`, via the WeightsLoader behind the opaque ModelWeights handle.
__C void
loadModelWeightDistributed(struct ModelWeights *weights_, const char *name, void *data, int *ranks, int nrank) {
    auto loader = reinterpret_cast<infinicore::WeightsLoader *>(weights_);
    loader->load_distributed_weight(std::string(name), data,
                                    std::vector<int>(ranks, ranks + nrank));
}
#ifndef WEIGHTS_LOADER_HPP
#define WEIGHTS_LOADER_HPP
#include "../tensor.hpp"
#include <unordered_map>
#include <vector>
namespace infinicore {
// Manages named device-resident weight tensors across tensor-parallel ranks,
// with one asynchronous load stream per rank.
class WeightsLoader {
protected:
    // _weights[rank]: weight name -> device tensor registered for that rank.
    std::vector<std::unordered_map<std::string, std::shared_ptr<Tensor>>> _weights;
    infiniDevice_t _device;
    // Physical device id for each rank.
    std::vector<int> _dev_ids;
    // One async copy stream per rank.
    std::vector<infinirtStream_t> _streams;

public:
    WeightsLoader(infiniDevice_t, const std::vector<int> &dev_ids);
    // Register `tensor` as the destination for weight `name` on `rank`.
    // NOTE(review): "resigter" is a typo for "register", kept for API compatibility.
    void resigter(const std::string &name, std::shared_ptr<Tensor> tensor, int rank = 0);
    // Copy one host buffer to every rank's registered tensor; blocks until done.
    void load_weight(const std::string &name, const void *host_data);
    // Scatter consecutive shards of host_data to the listed ranks; blocks until done.
    void load_distributed_weight(const std::string &name, const void *host_data, const std::vector<int> &ranks);
    // Synchronous single-rank load (no stream).
    void load_rank_weight(const std::string &name, const void *host_data, int rank);
    // Synchronize and destroy all streams; restores the previously current device.
    void finalize();
    std::shared_ptr<Tensor> get(const std::string &name, int rank = 0);
    const std::vector<int> &dev_ids() const { return _dev_ids; }
    infiniDevice_t device() const { return _device; }
};
} // namespace infinicore
#endif // WEIGHTS_LOADER_HPP
This diff is collapsed.
#include "deepseek_v3_impl.hpp"
// Allocate the per-device, per-layer MLA cache for a DeepSeek-V3 model:
// a compressed-KV buffer ("kv_pass", width r_kv) and the rotary slice of K
// ("k_rot", width d_rope) for each of the dense + sparse layers.
__C struct DeepSeekV3Cache *
createDeepSeekV3Cache(const struct DeepSeekV3Model *model) {
    auto *cache = new DeepSeekV3Cache();
    const auto ndev = model->dev_resources.size();
    const auto nlayer = model->meta.n_dense_layer + model->meta.n_sparse_layer;
    auto max_len = model->meta.dctx;
    auto d_rope = model->meta.d_rope;
    auto r_kv = model->meta.r_kv;
    auto kv_pass_shape = std::vector<size_t>{max_len, r_kv};
    auto k_rot_shape = std::vector<size_t>{max_len, d_rope};
    for (size_t idev = 0; idev < ndev; idev++) {
        // Buffers must be allocated with the owning device selected.
        RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
        std::vector<std::shared_ptr<Tensor>> kv_pass_cache, k_rot_cache;
        kv_pass_cache.reserve(nlayer);
        k_rot_cache.reserve(nlayer);
        for (size_t layer = 0; layer < nlayer; layer++) {
            kv_pass_cache.push_back(Tensor::buffer(model->meta.dt_logits, kv_pass_shape));
            k_rot_cache.push_back(Tensor::buffer(model->meta.dt_logits, k_rot_shape));
        }
        cache->kv_pass.push_back(std::move(kv_pass_cache));
        cache->k_rot.push_back(std::move(k_rot_cache));
    }
    return cache;
}
// Release every tensor of a DeepSeek-V3 cache (with its owning device
// selected) and delete the cache object itself.
__C void
dropDeepSeekV3Cache(const struct DeepSeekV3Model *model,
                    struct DeepSeekV3Cache *cache) {
    const auto ndev = model->dev_resources.size();
    const auto nlayer = model->meta.n_dense_layer + model->meta.n_sparse_layer;
    for (size_t idev = 0; idev < ndev; idev++) {
        RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[idev]));
        for (size_t layer = 0; layer < nlayer; layer++) {
            cache->kv_pass[idev][layer].reset();
            cache->k_rot[idev][layer].reset();
        }
    }
    delete cache;
}
\ No newline at end of file
#ifndef DEEPSEEK_V3_IMPL_H
#define DEEPSEEK_V3_IMPL_H
#include "infinicore_infer.h"
#include "../../allocator.hpp"
#include "../../tensor.hpp"
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
// AWQ-style quantized linear layer. Field names mirror the weight/scale/zero
// pointers of the loader callbacks (load_layer_linear_fn).
struct QuantLinearWeight {
    std::shared_ptr<Tensor> w; // packed quantized weight
    std::shared_ptr<Tensor> s; // dequantization scales
    std::shared_ptr<Tensor> z; // zero points
};
// Multi-head Latent Attention weights: RMS-norm tensors for the compressed
// KV / Q paths plus the quantized projections.
// NOTE(review): assumed *_a_proj = down/compress and *_b_proj = up/expand,
// per DeepSeek naming — confirm against the weight loader.
struct MLAWeight {
    std::shared_ptr<Tensor> kv_a_norm, q_a_norm;
    std::shared_ptr<QuantLinearWeight> kv_a_proj, kv_b_proj, o_proj, q_a_proj, q_b_proj;
};
// MoE router ("gate") parameters.
struct GateWeight {
    std::shared_ptr<Tensor> w; // routing weight matrix
    std::shared_ptr<Tensor> b; // presumably the routing bias — confirm
};
// Gated MLP (gate/up/down quantized projections), used for both dense MLPs
// and individual experts.
struct MLPWeight {
    std::shared_ptr<QuantLinearWeight> gate, up, down;
};
// Weights for one transformer layer. The model has n_dense_layer dense layers
// and n_sparse_layer MoE layers (see DeepSeekV3Cache allocation), so either
// `dense_mlp` or the {route, share_expert, experts} group is populated per
// layer — presumably mutually exclusive; confirm against layer construction.
struct LayerWeight {
    std::shared_ptr<Tensor> mla_norm;               // pre-attention RMS norm
    std::shared_ptr<MLAWeight> mla;                 // attention weights
    std::shared_ptr<Tensor> mlp_norm;               // pre-MLP RMS norm
    std::shared_ptr<MLPWeight> dense_mlp;           // dense layers only
    std::shared_ptr<GateWeight> route;              // MoE router
    std::shared_ptr<MLPWeight> share_expert;        // always-active shared expert
    std::vector<std::shared_ptr<MLPWeight>> experts; // routed experts
};
// All weights resident on one device, plus the stream used to load them.
struct DeepSeekV3DeviceWeights {
    // Embedding, final norm, output head, and RoPE sin/cos tables.
    std::shared_ptr<Tensor> w_in_embd, w_out_norm, w_out_embd, sin_table,
        cos_table;
    std::vector<LayerWeight> w_layers; // one entry per transformer layer
    infiniDevice_t device;
    int dev_id;
    infinirtStream_t load_stream; // stream used for async weight uploads
};
// Model weights sharded across `ndev` devices: one DeepSeekV3DeviceWeights
// per entry of `device_weights`.
struct DeepSeekV3Weights {
    std::vector<std::shared_ptr<DeepSeekV3DeviceWeights>> device_weights;
    DeepSeekV3Weights(const DeepSeekV3Meta *meta,
                      infiniDevice_t device,
                      int ndev,
                      const int *dev_ids);
};
// Everything one worker needs to run its shard of the model on one device.
struct DeepSeekV3DeviceResource {
    // Device
    infiniDevice_t device;
    int device_id;
    infiniopHandle_t handle; // infiniop operator-library handle for this device
    // Weights
    std::shared_ptr<DeepSeekV3DeviceWeights> weights;
    // Streams
    infinirtStream_t stream; // compute stream
    // Communicator
    infinicclComm_t comm; // collective-comm handle for tensor parallelism
    std::shared_ptr<MemoryPool> memory_pool; // scratch allocator for activations
};
// Handshake state between the control thread and one per-device worker
// thread; all flags are guarded by `mtx` and signalled via the condvars.
struct InferState {
    std::mutex mtx;
    std::condition_variable cv_load, cv_start, cv_done;
    bool loaded = false;  // worker finished loading / setup
    bool proceed = false; // a new request is ready for the worker
    bool exit_flag = false; // ask the worker to shut down
};
// One batched request, shared read-only with all worker threads. Fields
// mirror the parameters of the inferBatch/forwardBatch C API (see the ctypes
// bindings): flattened token ids, per-request lengths and start positions,
// per-request KV caches, per-request sampling parameters, and the output
// buffers (sampled tokens or raw logits).
struct InferRequest {
    const uint32_t *tokens;   // all requests' tokens, concatenated (ntok total)
    uint32_t ntok;
    const uint32_t *req_lens; // length of each of the nreq requests
    uint32_t nreq;
    const uint32_t *req_pos;  // starting position of each request in its cache
    struct DeepSeekV3Cache **kv_caches; // one cache per request
    const float *temperature; // per-request sampling parameters
    const uint32_t *topk;
    const float *topp;
    uint32_t *output; // sampled token ids (inferBatch path)
    void *logits;     // raw logits out (forwardBatch path)
};
// Top-level model object: one resource/state/thread triple per device, plus
// the shared request slot the workers consume.
struct DeepSeekV3Model {
    DeepSeekV3Meta meta;
    infiniDevice_t device;
    std::vector<int> dev_ids; // physical device id per rank
    std::vector<DeepSeekV3DeviceResource> dev_resources;
    std::vector<InferState> states;  // per-worker handshake state
    std::vector<std::thread> threads; // one worker thread per device
    InferRequest req; // current batched request, shared with workers
    DeepSeekV3Model(const DeepSeekV3Meta *, const DeepSeekV3Weights *weights);
};
// MLA inference cache, indexed as kv_pass[idev][layer] / k_rot[idev][layer]:
// compressed KV states and the rotary slice of K (see createDeepSeekV3Cache).
struct DeepSeekV3Cache {
    std::vector<std::vector<std::shared_ptr<Tensor>>> kv_pass, k_rot;
};
#endif
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment