from ctypes import POINTER, c_int, c_uint, c_void_p, byref
from pathlib import Path
import safetensors
import sys
import time

from libinfinicore_infer import (
    JiugeMeta,
    JiugeWeights,
    KVCache,
    DataType,
    DeviceType,
    create_jiuge_model,
    create_kv_cache,
    drop_kv_cache,
    infer_batch,
)
import torch
import transformers


class LlamaWeightsNaming:
    """Map logical weight roles to Llama-style state-dict key names."""

    def input_embd(self):
        return "model.embed_tokens.weight"

    def output_norm(self):
        return "model.norm.weight"

    def output_embd(self):
        return "lm_head.weight"

    def attn_norm(self, i):
        return f"model.layers.{i}.input_layernorm.weight"

    def attn_q(self, i):
        return f"model.layers.{i}.self_attn.q_proj.weight"

    def attn_k(self, i):
        return f"model.layers.{i}.self_attn.k_proj.weight"

    def attn_v(self, i):
        return f"model.layers.{i}.self_attn.v_proj.weight"

    def attn_o(self, i):
        return f"model.layers.{i}.self_attn.o_proj.weight"

    def attn_q_b(self, i):
        return f"model.layers.{i}.self_attn.q_proj.bias"

    def attn_k_b(self, i):
        return f"model.layers.{i}.self_attn.k_proj.bias"

    def attn_v_b(self, i):
        return f"model.layers.{i}.self_attn.v_proj.bias"

    def ffn_norm(self, i):
        return f"model.layers.{i}.post_attention_layernorm.weight"

    def gate(self, i):
        return f"model.layers.{i}.mlp.gate_proj.weight"

    def up(self, i):
        return f"model.layers.{i}.mlp.up_proj.weight"

    def down(self, i):
        return f"model.layers.{i}.mlp.down_proj.weight"

    @staticmethod
    def match(state_dict):
        """Return True if *state_dict* contains the keys this naming expects.

        Only two representative keys are probed (the final norm and the
        first layer's q_proj weight), not the full key set.
        """
        return (
            "model.norm.weight" in state_dict
            and "model.layers.0.self_attn.q_proj.weight" in state_dict
        )


class JiugeMetaFromLlama(JiugeMeta):
    """Populate a ``JiugeMeta`` structure from a Llama-style config object."""

    def __init__(self, config, dtype=torch.float16):
        """Build the metadata.

        Args:
            config: Model config exposing ``num_hidden_layers``,
                ``hidden_size``, ``num_attention_heads``,
                ``num_key_value_heads``, ``intermediate_size``,
                ``max_position_embeddings``, ``vocab_size``,
                ``rms_norm_eps`` and ``rope_theta``.
            dtype: Torch dtype for logits. ``torch.float16`` and
                ``torch.float32`` are mapped directly; any other value
                falls back to float16.
        """
        if dtype == torch.float16:
            dt_ = DataType.INFINI_DTYPE_F16
        elif dtype == torch.float32:
            dt_ = DataType.INFINI_DTYPE_F32
        else:
            # Unsupported dtype: fall back to F16 and keep the recorded torch
            # dtype in sync, so tensors converted with `torch_dtype_logits`
            # actually have the layout that `dt_logits` advertises.
            dt_ = DataType.INFINI_DTYPE_F16
            dtype = torch.float16

        super().__init__(
            dt_logits=dt_,
            nlayer=config.num_hidden_layers,
            d=config.hidden_size,
            nh=config.num_attention_heads,
            # A falsy num_key_value_heads means "same as attention heads".
            nkvh=(
                config.num_key_value_heads
                if config.num_key_value_heads
                else config.num_attention_heads
            ),
            dh=config.hidden_size // config.num_attention_heads,
            di=config.intermediate_size,
            dctx=config.max_position_embeddings,
            dvoc=config.vocab_size,
            epsilon=config.rms_norm_eps,
            theta=config.rope_theta,
            end_token=2,
        )
        self.torch_dtype_logits = dtype


class JiugeWeightsImpl(JiugeWeights):
    """Fill a ``JiugeWeights`` structure with raw pointers into torch tensors.

    Every tensor whose ``data_ptr()`` is handed out is also stored on the
    instance, so the memory stays referenced for as long as this object lives.
    """

    def __init__(
        self,
        meta,
        naming,
        state_dict,
        torch_dt_mat=torch.float16,
        torch_dt_norm=torch.float32,
        ndev=1,
    ):
        """Convert, shard and register all model weights.

        Args:
            meta: Metadata providing ``nlayer``, ``nh``, ``nkvh``, ``dh``,
                ``d``, ``di`` and ``torch_dtype_logits``.
            naming: Object mapping weight roles to ``state_dict`` keys.
            state_dict: Mapping from key name to torch tensor.
            torch_dt_mat: Torch dtype for projection matrices
                (float16 or float32).
            torch_dt_norm: Torch dtype for norm weights (float16 or float32).
            ndev: Number of devices the weights are split across.

        Raises:
            ValueError: If ``torch_dt_mat`` or ``torch_dt_norm`` is neither
                float16 nor float32.
        """
        nlayer = meta.nlayer
        nh = meta.nh
        nkvh = meta.nkvh
        dh = meta.dh
        d = meta.d
        di = meta.di
        # Heads and the FFN width must split evenly across devices.
        assert nh % nkvh == 0
        assert nh % ndev == 0
        assert nkvh % ndev == 0
        assert di % ndev == 0
        torch_dt_logits = meta.torch_dtype_logits

        if torch_dt_mat == torch.float16:
            self.dt_mat = DataType.INFINI_DTYPE_F16
        elif torch_dt_mat == torch.float32:
            self.dt_mat = DataType.INFINI_DTYPE_F32
        else:
            raise ValueError("Unsupported proj weight data type")

        if torch_dt_norm == torch.float16:
            self.dt_norm = DataType.INFINI_DTYPE_F16
        elif torch_dt_norm == torch.float32:
            self.dt_norm = DataType.INFINI_DTYPE_F32
        else:
            raise ValueError("Unsupported norm weight data type")

        self.nlayer = nlayer

        # --- Embeddings and final norm -------------------------------------
        self.input_embd_tensor = state_dict[naming.input_embd()].to(torch_dt_logits)
        self.input_embd = self.input_embd_tensor.data_ptr()
        self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_norm)
        self.output_norm = self.output_norm_tensor.data_ptr()
        self.output_embd_tensor = state_dict[naming.output_embd()].to(torch_dt_mat)
        self.output_embd = self.output_embd_tensor.data_ptr()

        # --- Attention norm ------------------------------------------------
        self.attn_norm_tensors = [
            state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.attn_norm_ptrs = [t.data_ptr() for t in self.attn_norm_tensors]
        self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs)

        # --- Fused QKV weights ---------------------------------------------
        def qkv_slices(_i):
            """Return layer ``_i``'s per-device [Q, K, V] slices, in order."""
            # Q and K are viewed as [heads, 2, dh/2, d] and the two middle
            # axes swapped, giving [heads, dh/2, 2, d]; V is viewed directly
            # as [heads, dh/2, 2, d] with no reordering.
            # NOTE(review): this looks like a rotary-embedding layout
            # conversion for Q/K -- confirm against the kernel's expectation.
            _Q = (
                state_dict[naming.attn_q(_i)]
                .reshape([nh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _K = (
                state_dict[naming.attn_k(_i)]
                .reshape([nkvh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
                _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :])
                _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :])
                _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            return _result

        self.qkv_tensor = [
            torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        self.qkv_tensor_ptrs = [t.data_ptr() for t in self.qkv_tensor]
        self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)

        # --- Fused QKV biases (optional) -----------------------------------
        def qkv_b_slices(_i):
            """Return layer ``_i``'s per-device [Q, K, V] bias slices."""
            _QB = (
                state_dict[naming.attn_q_b(_i)]
                .reshape([nh, 2, dh // 2])
                .transpose(1, 2)
            )
            _KB = (
                state_dict[naming.attn_k_b(_i)]
                .reshape([nkvh, 2, dh // 2])
                .transpose(1, 2)
            )
            _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
                _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :])
                _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
                _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            return _result

        # Presence of the first layer's Q bias decides whether biases are
        # loaded for every layer.
        if naming.attn_q_b(0) in state_dict:
            self.qkv_b_tensors = [
                torch.concat(qkv_b_slices(i)).to(torch_dt_logits)
                for i in range(nlayer)
            ]
            self.qkv_b_tensor_ptrs = [t.data_ptr() for t in self.qkv_b_tensors]
            self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs)
        else:
            self.attn_qkv_b = None

        # --- Attention output projection -----------------------------------
        # [d, nh*dh] -> [ndev, d, nh/ndev*dh]: one contiguous column shard
        # per device.
        self.attn_o_tensor = [
            state_dict[naming.attn_o(i)]
            .to(torch_dt_mat)
            .reshape([d, ndev, nh // ndev * dh])
            .transpose(0, 1)
            .contiguous()
            for i in range(nlayer)
        ]
        self.attn_o_ptrs = [t.data_ptr() for t in self.attn_o_tensor]
        self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)

        # --- FFN norm ------------------------------------------------------
        self.ffn_norm_tensors = [
            state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.ffn_norm_ptrs = [t.data_ptr() for t in self.ffn_norm_tensors]
        self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs)

        # --- Fused gate/up projection --------------------------------------
        def gate_up_slices(_i):
            """Return layer ``_i``'s per-device [gate, up] row slices."""
            _result = []
            _di = di // ndev
            for _idev in range(ndev):
                _start = _idev * _di
                _end = (_idev + 1) * _di
                _result.append(state_dict[naming.gate(_i)][_start:_end, :])
                _result.append(state_dict[naming.up(_i)][_start:_end, :])
            return _result

        self.gate_up_tensors = [
            torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        self.gate_up_ptrs = [t.data_ptr() for t in self.gate_up_tensors]
        self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)

        # --- FFN down projection -------------------------------------------
        # [d, di] -> [ndev, d, di/ndev]: one contiguous column shard per device.
        self.ffn_down_tensor = [
            state_dict[naming.down(i)]
            .to(torch_dt_mat)
            .reshape([d, ndev, di // ndev])
            .transpose(0, 1)
            .contiguous()
            for i in range(nlayer)
        ]
        self.ffn_down_ptrs = [t.data_ptr() for t in self.ffn_down_tensor]
        self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)


class JiugeForCauslLM:
    """Load a model directory and run greedy/sampled generation through Jiuge."""

    def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1):
        """Load config, weights and tokenizer, then create the model instance.

        Args:
            model_dir_path: Directory holding the config, tokenizer and
                weights.
            device: ``DeviceType`` value passed to ``create_jiuge_model``.
            ndev: Number of devices; device ids ``0 .. ndev-1`` are used.

        Raises:
            ValueError: If the model type is neither ``"llama"`` nor
                ``"fm9g"``, or if an ``fm9g`` checkpoint does not use
                Llama-style weight names.
        """

        def load_all_safetensors_from_dir(dir_path_: str):
            """Merge every ``*.safetensors`` file in the directory into one dict."""
            tensors_ = {}
            dir_path_ = Path(dir_path_)
            for file in sorted(dir_path_.glob("*.safetensors")):
                # Context manager closes the file handle once its tensors
                # have been read.
                with safetensors.safe_open(file, "pt") as data_:
                    for name_ in data_.keys():
                        tensors_[name_] = data_.get_tensor(name_)
            return tensors_

        config = transformers.AutoConfig.from_pretrained(
            model_dir_path, trust_remote_code=True
        )
        if "llama" == config.model_type:
            model = transformers.LlamaForCausalLM.from_pretrained(
                model_dir_path
            ).half()
            self.meta = JiugeMetaFromLlama(model.config)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
            self.weights = JiugeWeightsImpl(
                self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev
            )
        elif "fm9g" == config.model_type:
            state_dict = load_all_safetensors_from_dir(model_dir_path)
            if not LlamaWeightsNaming.match(state_dict):
                # Fail here instead of leaving self.meta / self.weights unset
                # and crashing later with an AttributeError.
                raise ValueError("Unsupported weight naming")
            self.meta = JiugeMetaFromLlama(config)
            self.weights = JiugeWeightsImpl(
                self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
            )
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_dir_path, trust_remote_code=True
            )
        else:
            raise ValueError("Unsupported model architecture")

        dev_ids = (c_int * ndev)(*range(ndev))
        self.model_instance = create_jiuge_model(
            byref(self.meta),
            byref(self.weights),
            device,
            ndev,
            dev_ids,
        )

    def infer(self, input_list, topp=1.0, topk=1, temperature=1.0):
        """Placeholder; not implemented."""
        pass

    def generate(self, input_content, max_steps, topp=1.0, topk=1, temperature=1.0):
        """Generate up to ``max_steps`` tokens, streaming them to stdout.

        Args:
            input_content: Prompt text.
            max_steps: Maximum number of inference steps.
            topp: Passed through to ``infer_batch``.
            topk: Passed through to ``infer_batch``.
            temperature: Passed through to ``infer_batch``.

        Returns:
            ``(output_content, avg_time)``: the generated text (without the
            prompt) and the mean wall-clock time per step in milliseconds
            (``0.0`` when no step was run).
        """
        print(input_content, end="", flush=True)
        kv_cache = create_kv_cache(self.model_instance)
        nreq = 1
        kv_caches = (POINTER(KVCache) * nreq)(kv_cache)

        output_parts = []
        steps = 0
        try:
            prompt_ids = self.tokenizer.encode(input_content)
            ntok = len(prompt_ids)
            tokens = (c_uint * ntok)(*prompt_ids)
            req_lens = (c_uint * nreq)(ntok)
            req_pos = (c_uint * nreq)(0)
            ans = (c_uint * nreq)()
            end_token = self.meta.end_token

            start_time = time.time()
            for _ in range(max_steps):
                infer_batch(
                    self.model_instance,
                    tokens,
                    ntok,
                    req_lens,
                    nreq,
                    req_pos,
                    kv_caches,
                    ans,
                    temperature,
                    topk,
                    topp,
                )
                steps += 1
                output_tokens = list(ans)
                # "▁" and "<0x0A>" are mapped to a space and a newline.
                # NOTE(review): presumably SentencePiece's word-boundary
                # marker and byte-fallback newline -- confirm for the
                # tokenizers in use.
                output_str = (
                    self.tokenizer._tokenizer.id_to_token(output_tokens[0])
                    .replace("▁", " ")
                    .replace("<0x0A>", "\n")
                )
                # Stop on end-of-sequence. The previous `endswith("")` test
                # was true for every string, so generation always stopped
                # after one step with empty output.
                if output_tokens[0] == end_token or output_str.endswith("</s>"):
                    break
                output_parts.append(output_str)
                print(output_str, end="", flush=True)
                # After the first step only the newly sampled token is fed,
                # positioned right after everything consumed so far.
                req_pos[0] = req_pos[0] + ntok
                ntok = 1
                tokens = (c_uint * ntok)(*output_tokens)
                req_lens = (c_uint * nreq)(ntok)

            print("\n")
            end_time = time.time()
        finally:
            # Release the cache even if inference or tokenisation raised.
            for cache in kv_caches:
                drop_kv_cache(self.model_instance, cache)

        # Guard against max_steps == 0 (no step executed).
        avg_time = (end_time - start_time) * 1000 / steps if steps else 0.0
        print(f"Time per step: {avg_time:.3f}ms")
        return "".join(output_parts), avg_time


def test():
    """Command-line entry point: load a model and generate from a fixed prompt.

    Expects ``argv[1]`` to be a device flag, ``argv[2]`` the model directory
    and optional ``argv[3]`` the device count (default 1). Prints a usage
    message and exits with status 1 on bad arguments.
    """
    usage = (
        "Usage: python test_llama.py "
        "[--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] "
        "<path/to/model_dir> [n_device]"
    )
    device_types = {
        "--cpu": DeviceType.DEVICE_TYPE_CPU,
        "--nvidia": DeviceType.DEVICE_TYPE_NVIDIA,
        "--cambricon": DeviceType.DEVICE_TYPE_CAMBRICON,
        "--ascend": DeviceType.DEVICE_TYPE_ASCEND,
        "--metax": DeviceType.DEVICE_TYPE_METAX,
        "--moore": DeviceType.DEVICE_TYPE_MOORE,
    }

    if len(sys.argv) < 3 or sys.argv[1] not in device_types:
        print(usage)
        sys.exit(1)

    device_type = device_types[sys.argv[1]]
    model_path = sys.argv[2]
    ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1

    model = JiugeForCauslLM(model_path, device_type, ndev)
    model.generate("Once upon a time,", 100)


if __name__ == "__main__":
    test()