Commit 366d3aef authored by Pan Zezhong

support 9g7b

parent c3d5efa5
from ctypes import POINTER, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
from libinfinicore_infer import (
JiugeMeta,
@@ -18,6 +20,7 @@ from libinfinicore_infer import (
import torch
import transformers
+ torch.set_default_device("cpu")
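# Note: defaulting to CPU keeps every weight tensor in host memory, since raw data_ptr() values are later handed to the C inference library.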
class LlamaWeightsNaming:
def input_embd(self):
@@ -73,8 +76,8 @@ class LlamaWeightsNaming:
class JiugeMetaFromLlama(JiugeMeta):
- def __init__(self, config, dtype = torch.float16):
- if dtype == torch.float16:
+ def __init__(self, config, dtype=torch.float16):
+ if dtype == torch.float16:
dt_ = DataType.INFINI_DTYPE_F16
elif dtype == torch.float32:
dt_ = DataType.INFINI_DTYPE_F32
@@ -82,27 +85,35 @@ class JiugeMetaFromLlama(JiugeMeta):
dt_ = DataType.INFINI_DTYPE_F16
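# The config is now a plain dict parsed from config.json rather than a transformers config object; rope_theta is presumably optional in 9G configs, hence the 100000.0 fallback below.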
super().__init__(
dt_logits=dt_,
- nlayer=config.num_hidden_layers,
- d=config.hidden_size,
- nh=config.num_attention_heads,
+ nlayer=config["num_hidden_layers"],
+ d=config["hidden_size"],
+ nh=config["num_attention_heads"],
nkvh=(
- config.num_key_value_heads
- if config.num_key_value_heads
- else config.num_attention_heads
+ config["num_key_value_heads"]
+ if "num_key_value_heads" in config
+ else config["num_attention_heads"]
),
- dh=config.hidden_size // config.num_attention_heads,
- di=config.intermediate_size,
- dctx=config.max_position_embeddings,
- dvoc=config.vocab_size,
- epsilon=config.rms_norm_eps,
- theta=config.rope_theta,
+ dh=config["hidden_size"] // config["num_attention_heads"],
+ di=config["intermediate_size"],
+ dctx=config["max_position_embeddings"],
+ dvoc=config["vocab_size"],
+ epsilon=config["rms_norm_eps"],
+ theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
end_token=2,
)
self.torch_dtype_logits = dtype
class JiugeWeightsImpl(JiugeWeights):
- def __init__(self, meta, naming, state_dict, torch_dt_mat = torch.float16, torch_dt_norm = torch.float32, ndev=1):
+ def __init__(
+ self,
+ meta,
+ naming,
+ state_dict,
+ torch_dt_mat=torch.float16,
+ torch_dt_norm=torch.float32,
+ ndev=1,
+ ):
nlayer = meta.nlayer
nh = meta.nh
nkvh = meta.nkvh
@@ -127,12 +138,23 @@ class JiugeWeightsImpl(JiugeWeights):
else:
raise ValueError("Unsupported norm weight data type")
+ input_embd_naming = (
+ naming.input_embd()
+ if naming.input_embd() in state_dict
+ else naming.output_embd()
+ )
+ output_embd_naming = (
+ naming.output_embd()
+ if naming.output_embd() in state_dict
+ else naming.input_embd()
+ )
self.nlayer = nlayer
- self.input_embd_tensor = state_dict[naming.input_embd()].to(torch_dt_logits)
+ self.input_embd_tensor = state_dict[input_embd_naming].to(torch_dt_logits)
self.input_embd = self.input_embd_tensor.data_ptr()
self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm)
self.output_norm = self.output_norm_tensor.data_ptr()
- self.output_embd_tensor = state_dict[naming.output_embd()].to(torch_dt_mat)
+ self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
self.output_embd = self.output_embd_tensor.data_ptr()
self.attn_norm_tensors = [
@@ -164,7 +186,9 @@ class JiugeWeightsImpl(JiugeWeights):
_result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
return _result
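# qkv_slices gathers each device's Q, K, and V shards in turn, so the concatenated per-layer QKV tensor is contiguous per device (tensor-parallel layout).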
- self.qkv_tensor = [torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)]
+ self.qkv_tensor = [
+ torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
+ ]
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)
@@ -184,13 +208,15 @@ class JiugeWeightsImpl(JiugeWeights):
_nh = nh // ndev
_nkvh = nkvh // ndev
for _idev in range(ndev):
- _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :])
- _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
- _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
+ _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten())
+ _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
+ _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
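# The bias shards are now flattened to 1-D so the concatenated per-layer bias matches the 1-D shape used by getAttnQKVBias on the C++ side (changed below).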
return _result
if naming.attn_q_b(0) in state_dict:
- self.qkv_b_tensors = [torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)]
+ self.qkv_b_tensors = [
+ torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)
+ ]
self.qkv_b_tensor_ptrs = [
self.qkv_b_tensors[i].data_ptr() for i in range(nlayer)
]
@@ -199,7 +225,8 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_qkv_b = None
self.attn_o_tensor = [
- state_dict[naming.attn_o(i)].to(torch_dt_mat)
+ state_dict[naming.attn_o(i)]
+ .to(torch_dt_mat)
.reshape([d, ndev, nh // ndev * dh])
.transpose(0, 1)
.contiguous()
@@ -208,7 +235,9 @@ class JiugeWeightsImpl(JiugeWeights):
self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)
- self.ffn_norm_tensors = [state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)]
+ self.ffn_norm_tensors = [
+ state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
+ ]
self.ffn_norm_ptrs = [
self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer)
]
@@ -224,12 +253,15 @@ class JiugeWeightsImpl(JiugeWeights):
_result.append(state_dict[naming.up(_i)][_start:_end, :])
return _result
- self.gate_up_tensors = [torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)]
+ self.gate_up_tensors = [
+ torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
+ ]
self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
self.ffn_down_tensor = [
- state_dict[naming.down(i)].to(torch_dt_mat)
+ state_dict[naming.down(i)]
+ .to(torch_dt_mat)
.reshape([d, ndev, di // ndev])
.transpose(0, 1)
.contiguous()
@@ -250,17 +282,17 @@ class JiugeForCauslLM:
tensors_[name_] = data_.get_tensor(name_)
return tensors_
- config = transformers.AutoConfig.from_pretrained(
- model_dir_path, trust_remote_code=True
- )
- if "llama" == config.model_type:
- model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).half()
- self.meta = JiugeMetaFromLlama(model.config)
+ with open(os.path.join(model_dir_path, "config.json"), "r") as f:
+ config = json.load(f)
+ if "llama" == config["model_type"]:
+ model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half()
+ self.meta = JiugeMetaFromLlama(config)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
)
elif "fm9g" == config.model_type:
elif "fm9g" == config["model_type"]:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
@@ -270,6 +302,19 @@ class JiugeForCauslLM:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
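# fm9g7b ships a single pytorch_model.bin instead of safetensors; weights_only=True keeps torch.load from executing arbitrary pickled code.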
elif "fm9g7b" == config["model_type"]:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
else:
raise ValueError("Unsupported model architecture")
dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
@@ -285,6 +330,11 @@ class JiugeForCauslLM:
pass
def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
+ input_content = self.tokenizer.apply_chat_template(
+ conversation=[{"role": "user", "content": input_content}],
+ add_generation_prompt=True,
+ tokenize=False,
+ )
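# Wrapping the prompt with the tokenizer's chat template gives instruction-tuned checkpoints the input format they expect.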
print(input_content, end="", flush=True)
kv_cache = create_kv_cache(self.model_instance)
tokens = self.tokenizer.encode(input_content)
@@ -298,8 +348,10 @@ class JiugeForCauslLM:
ans = (c_uint * nreq)()
steps = 0
- start_time = time.time()
- for _ in range(max_steps):
+ total_time = 0
+ for step_i in range(max_steps):
+ start_time = time.time()
infer_batch(
self.model_instance,
tokens,
@@ -324,15 +376,16 @@ class JiugeForCauslLM:
break
output_content += output_str
print(output_str, end="", flush=True)
# print(output_tokens[0])
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
req_lens = (c_uint * nreq)(*[ntok])
+ end_time = time.time()
+ if step_i > 0:
+ total_time += end_time - start_time
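# Step 0 is excluded from the average: it prefills the whole prompt, so its latency is not representative of a single decode step.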
print("\n")
- end_time = time.time()
- avg_time = (end_time - start_time) * 1000 / steps
+ avg_time = total_time * 1000 / (steps - 1)
print(f"Time per step: {avg_time:.3f}ms")
for kv_cache in kv_caches:
drop_kv_cache(self.model_instance, kv_cache)
@@ -367,7 +420,7 @@ def test():
ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
model = JiugeForCauslLM(model_path, device_type, ndev)
model.generate("Once upon a time,", 100)
model.generate("山东最高的山是?", 500)
if __name__ == "__main__":
...
@@ -243,12 +243,12 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
if (has_qkv_bias) {
RUN_INFINI(infiniopRearrange(
desc_qkv_bias,
- qkv_buf->data(), rsrc.b_attn_qkv.data(), stream));
+ qkv_buf->data(), rsrc.b_attn_qkv[layer]->data(), stream));
}
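// The rearrange above pre-stages the bias in qkv_buf, so the GEMM now uses beta = 1.0 to accumulate onto it; the old beta of 0.0 overwrote the bias.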
RUN_INFINI(infiniopGemm(
desc_attn_qkv, workspace, workspace_size,
qkv_buf->data(), logits_out->data(),
- rsrc.w_attn_qkv[layer]->data(), 1.0, 0.0, stream));
+ rsrc.w_attn_qkv[layer]->data(), 1.0, has_qkv_bias ? 1.0 : 0.0, stream));
// rope
RUN_INFINI(infiniopRoPE(
desc_rope_q, workspace, workspace_size,
...
@@ -56,7 +56,7 @@ inline std::shared_ptr<Tensor> getAttnQKVBias(
auto nh = meta->nh;
auto dh = meta->dh;
size_t offset = idev * ((nkvh * 2 + nh) / ndev * dh) * dsize(w->dt_mat);
- auto shape = std::vector<size_t>({1, (nh + 2 * nkvh) / ndev * dh});
+ auto shape = std::vector<size_t>({(nh + 2 * nkvh) / ndev * dh});
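// The bias view drops its leading singleton dimension; a 1-D shape matches the flattened per-device bias slices now built in JiugeWeightsImpl.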
return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape);
}
...