from typing import List, Sequence import math import os from pathlib import Path import safetensors import sys import time import json import torch import transformers from libinfinicore_infer import ( JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct, DataType, DeviceType, KVCacheCStruct, ) from infer_task import InferTask, KVCache from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref torch.set_default_device("cpu") class LlamaWeightsNaming: def input_embd(self): return "model.embed_tokens.weight" def output_norm(self): return "model.norm.weight" def output_embd(self): return "lm_head.weight" def attn_norm(self, i): return f"model.layers.{i}.input_layernorm.weight" def attn_q(self, i): return f"model.layers.{i}.self_attn.q_proj.weight" def attn_k(self, i): return f"model.layers.{i}.self_attn.k_proj.weight" def attn_v(self, i): return f"model.layers.{i}.self_attn.v_proj.weight" def attn_o(self, i): return f"model.layers.{i}.self_attn.o_proj.weight" def attn_q_b(self, i): return f"model.layers.{i}.self_attn.q_proj.bias" def attn_k_b(self, i): return f"model.layers.{i}.self_attn.k_proj.bias" def attn_v_b(self, i): return f"model.layers.{i}.self_attn.v_proj.bias" def attn_q_norm(self, i): return f"model.layers.{i}.self_attn.q_norm.weight" def attn_k_norm(self, i): return f"model.layers.{i}.self_attn.k_norm.weight" def ffn_norm(self, i): return f"model.layers.{i}.post_attention_layernorm.weight" def gate(self, i): return f"model.layers.{i}.mlp.gate_proj.weight" def up(self, i): return f"model.layers.{i}.mlp.up_proj.weight" def down(self, i): return f"model.layers.{i}.mlp.down_proj.weight" def match(state_dict): return ( "model.norm.weight" in state_dict and "model.layers.0.self_attn.q_proj.weight" in state_dict ) class JiugeMetaFromLlama(JiugeMetaCStruct): def __init__(self, config, dtype=torch.float16, max_tokens=None): if dtype == torch.float16: dt_ = DataType.INFINI_DTYPE_F16 elif dtype == torch.float32: dt_ = DataType.INFINI_DTYPE_F32 elif dtype == torch.bfloat16: dt_ = DataType.INFINI_DTYPE_BF16 else: dt_ = DataType.INFINI_DTYPE_F16 self.scale_input = 1.0 self.scale_output = 1.0 self.scale_o = 1.0 self.scale_down = 1.0 if ( config["model_type"] in ["fm9g", "minicpm"] and "scale_emb" in config and "scale_depth" in config and "dim_model_base" in config ): self.scale_input = config["scale_emb"] self.scale_output = config["hidden_size"] // config["dim_model_base"] self.scale_o = config["scale_depth"] / math.sqrt( config["num_hidden_layers"] ) self.scale_down = config["scale_depth"] / math.sqrt( config["num_hidden_layers"] ) super().__init__( dt_logits=dt_, nlayer=config["num_hidden_layers"], d=config["hidden_size"], nh=config["num_attention_heads"], nkvh=( config["num_key_value_heads"] if "num_key_value_heads" in config else config["num_attention_heads"] ), dh=( config["head_dim"] if "head_dim" in config else config["hidden_size"] // config["num_attention_heads"] ), di=config["intermediate_size"], dctx=( config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], epsilon=config["rms_norm_eps"], theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), end_token=2, ) self.torch_dtype_logits = dtype class JiugeWeightsImpl(JiugeWeightsCStruct): def __init__( self, meta, naming, state_dict, torch_dt_mat=torch.float16, torch_dt_norm=torch.float32, ndev=1, transpose_weight=True, ): nlayer = meta.nlayer nh = meta.nh nkvh = meta.nkvh dh = meta.dh d = meta.d di = meta.di scale_input = meta.scale_input scale_output = meta.scale_output scale_o = meta.scale_o scale_down = meta.scale_down assert nh % nkvh == 0 assert nh % ndev == 0 assert nkvh % ndev == 0 assert di % ndev == 0 torch_dt_logits = meta.torch_dtype_logits if torch_dt_mat == torch.float16: self.dt_mat = DataType.INFINI_DTYPE_F16 elif torch_dt_mat == torch.float32: self.dt_mat = DataType.INFINI_DTYPE_F32 elif torch_dt_mat == torch.bfloat16: self.dt_mat = DataType.INFINI_DTYPE_BF16 else: raise ValueError("Unsupported proj weight data type") if torch_dt_norm == torch.float16: self.dt_norm = DataType.INFINI_DTYPE_F16 elif torch_dt_norm == torch.float32: self.dt_norm = DataType.INFINI_DTYPE_F32 elif torch_dt_norm == torch.bfloat16: self.dt_norm = DataType.INFINI_DTYPE_BF16 else: raise ValueError("Unsupported norm weight data type") input_embd_naming = ( naming.input_embd() if naming.input_embd() in state_dict else naming.output_embd() ) output_embd_naming = ( naming.output_embd() if naming.output_embd() in state_dict else naming.input_embd() ) self.transpose_linear_weights = 1 if transpose_weight else 0 self.nlayer = nlayer self.input_embd_tensor = ( state_dict[input_embd_naming].to(torch_dt_logits) * scale_input ) self.input_embd = self.input_embd_tensor.data_ptr() self.output_norm_tensor = ( state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output ) self.output_norm = self.output_norm_tensor.data_ptr() self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) if not transpose_weight: self.output_embd_tensor = self.output_embd_tensor.transpose( 0, 1 ).contiguous() self.output_embd = self.output_embd_tensor.data_ptr() self.attn_norm_tensors = [ state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) ] self.attn_norm_ptrs = [ self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) ] self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) def qkv_slices(_i): _Q = ( state_dict[naming.attn_q(_i)] .reshape([nh, 2, dh // 2, d]) .transpose(1, 2) ) _K = ( state_dict[naming.attn_k(_i)] .reshape([nkvh, 2, dh // 2, d]) .transpose(1, 2) ) _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) _result = [] _nh = nh // ndev _nkvh = nkvh // ndev for _idev in range(ndev): _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) return _result self.qkv_tensor = [ torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) ] if not transpose_weight: for i in range(nlayer): self.qkv_tensor[i] = ( self.qkv_tensor[i] .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) .transpose(1, 2) .contiguous() ) self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) def qkv_b_slices(_i): _QB = ( state_dict[naming.attn_q_b(_i)] .reshape([nh, 2, dh // 2]) .transpose(1, 2) ) _KB = ( state_dict[naming.attn_k_b(_i)] .reshape([nkvh, 2, dh // 2]) .transpose(1, 2) ) _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2]) _result = [] _nh = nh // ndev _nkvh = nkvh // ndev for _idev in range(ndev): _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten()) _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) return _result if naming.attn_q_b(0) in state_dict: self.qkv_b_tensors = [ torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer) ] self.qkv_b_tensor_ptrs = [ self.qkv_b_tensors[i].data_ptr() for i in range(nlayer) ] self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs) else: self.attn_qkv_b = None if naming.attn_q_norm(0) in state_dict: self.attn_q_norm_tensors = [ state_dict[naming.attn_q_norm(i)] .reshape([2, dh // 2]) .transpose(0, 1) .contiguous() .to(torch_dt_norm) for i in range(nlayer) ] self.attn_q_norm_ptrs = [ self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer) ] self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs) self.attn_k_norm_tensors = [ state_dict[naming.attn_k_norm(i)] .reshape([2, dh // 2]) .transpose(0, 1) .contiguous() .to(torch_dt_norm) for i in range(nlayer) ] self.attn_k_norm_ptrs = [ self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer) ] self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs) else: self.attn_q_norm = None self.attn_k_norm = None self.attn_o_tensor = [ ( state_dict[naming.attn_o(i)] .to(torch_dt_mat) .reshape([d, ndev, nh // ndev * dh]) .transpose(0, 1) .contiguous() if transpose_weight else state_dict[naming.attn_o(i)] .transpose(0, 1) .to(torch_dt_mat) .contiguous() ) * scale_o for i in range(nlayer) ] self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) self.ffn_norm_tensors = [ state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) ] self.ffn_norm_ptrs = [ self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) ] self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) def gate_up_slices(_i): _result = [] _di = di // ndev for _idev in range(ndev): _start = _idev * _di _end = (_idev + 1) * _di _result.append(state_dict[naming.gate(_i)][_start:_end, :]) _result.append(state_dict[naming.up(_i)][_start:_end, :]) return _result self.gate_up_tensors = [ torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) ] if not transpose_weight: for i in range(nlayer): self.gate_up_tensors[i] = ( self.gate_up_tensors[i] .reshape(ndev, 2 * di // ndev, d) .transpose(1, 2) .contiguous() ) self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) self.ffn_down_tensor = [ ( state_dict[naming.down(i)] .to(torch_dt_mat) .reshape([d, ndev, di // ndev]) .transpose(0, 1) .contiguous() if transpose_weight else state_dict[naming.down(i)] .transpose(0, 1) .to(torch_dt_mat) .contiguous() ) * scale_down for i in range(nlayer) ] self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) class JiugeBatchedTask: def __init__(self, tasks: List[InferTask]): self.tasks = tasks self.nreq = len(tasks) # Precompute fields token_lists = [t.tokens for t in tasks] self.req_lens_list = [len(toks) for toks in token_lists] self.req_pos_list = [t.pos for t in tasks] self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] self.temperaturas_list = [t.temperature for t in tasks] self.topks_list = [t.topk for t in tasks] self.topps_list = [t.topp for t in tasks] # Flatten token lists flat_tokens = [tok for toks in token_lists for tok in toks] self.ntok = len(flat_tokens) # Convert to ctypes arrays in one pass self.tokens = (c_uint * self.ntok)(*flat_tokens) self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) self.topks = (c_uint * self.nreq)(*self.topks_list) self.topps = (c_float * self.nreq)(*self.topps_list) def input_args(self): return ( self.tokens, self.ntok, self.req_lens, self.nreq, self.req_pos, self.kv_caches, self.temperaturas, self.topks, self.topps, ) class JiugeForCauslLM: def __init__( self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None ): def load_all_safetensors_from_dir(dir_path_: str): tensors_ = {} dir_path_ = Path(dir_path_) for file in sorted(dir_path_.glob("*.safetensors")): data_ = safetensors.safe_open(file, "pt") for name_ in data_.keys(): tensors_[name_] = data_.get_tensor(name_) return tensors_ print("Loading model weights to host...") load_start_time = time.time() with open(os.path.join(model_dir_path, "config.json"), "r") as f: config = json.load(f) self.config = config eos_token_id = self.config["eos_token_id"] self.eos_token_id = ( [eos_token_id] if type(eos_token_id) == int else eos_token_id ) transpose_weight = ( device != DeviceType.DEVICE_TYPE_ASCEND ) # y = xW is faster than y=xW^T on Ascend self.jiuge_model = JiugeModel() if "llama" == config["model_type"]: model = ( transformers.LlamaForCausalLM.from_pretrained(model_dir_path) .cpu() .half() ) self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), model.state_dict(), ndev=ndev, transpose_weight=transpose_weight, ) elif "fm9g" == config["model_type"] or "minicpm" == config["model_type"]: if any( file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() ): state_dict = load_all_safetensors_from_dir(model_dir_path) else: state_dict = torch.load( os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu", ) if LlamaWeightsNaming.match(state_dict): self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight, ) self.tokenizer = transformers.AutoTokenizer.from_pretrained( model_dir_path, trust_remote_code=True ) else: raise ValueError("Unsupported weight naming") elif "fm9g7b" == config["model_type"]: if any( file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() ): state_dict = load_all_safetensors_from_dir(model_dir_path) else: state_dict = torch.load( os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu", ) if LlamaWeightsNaming.match(state_dict): self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight, ) self.tokenizer = transformers.AutoTokenizer.from_pretrained( model_dir_path, trust_remote_code=True ) else: raise ValueError("Unsupported weight naming") elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]: state_dict = load_all_safetensors_from_dir(model_dir_path) if LlamaWeightsNaming.match(state_dict): self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) self.weights = JiugeWeightsImpl( self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev, transpose_weight=transpose_weight, ) self.tokenizer = transformers.AutoTokenizer.from_pretrained( model_dir_path ) else: raise ValueError("Unsupported model architecture") if "llama" == config["model_type"]: from tokenizers import decoders as _dec backend = getattr(self.tokenizer, "backend_tokenizer", None) target = getattr(backend, "_tokenizer", backend) norm = getattr(target, "normalizer", None) dec = getattr(target, "decoder", None) sn = repr(norm)[:800] if norm is not None else "" sd = repr(dec)[:800] if dec is not None else "" has_prepend = "Prepend" in sn has_strip = "Strip" in sd if has_prepend and has_strip: target.decoder = _dec.Sequence( [ _dec.Replace("▁", " "), _dec.ByteFallback(), _dec.Fuse(), ] ) load_end_time = time.time() print(f"Time used: {load_end_time - load_start_time:.3f}s") print(f"Creating model on {ndev} devices...") load_start_time = time.time() self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) self.ndev = ndev self.device = device self.model_instance = self.jiuge_model.create_model( byref(self.meta), byref(self.weights), device, ndev, self.dev_ids, ) load_end_time = time.time() print(f"Time used: {load_end_time - load_start_time:.3f}s") def max_context_len(self): return self.meta.dctx def create_kv_cache(self): return self.jiuge_model.create_kv_cache( self.meta.nlayer, self.meta.dctx, self.meta.nkvh, self.meta.dh, self.meta.dh, self.meta.dt_logits, self.device, self.dev_ids, self.ndev, ) def drop_kv_cache(self, kv_cache): self.jiuge_model.drop_kv_cache(kv_cache) def batch_infer_one_round(self, tasks: List[InferTask]): output = (c_uint * len(tasks))() batch_inputs = JiugeBatchedTask(tasks) self.jiuge_model.infer_batch( self.model_instance, *(batch_inputs.input_args()), output, ) return list(output) def generate( self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0, verbose=False, ): input_content = self.tokenizer.apply_chat_template( conversation=[{"role": "user", "content": input_content}], add_generation_prompt=True, tokenize=False, ) print(input_content, end="", flush=True) tokens = self.tokenizer.encode(input_content) infer_task = InferTask( 0, tokens, self.max_context_len(), temperature_, topk_, topp_, self.eos_token_id, ) infer_task.bind_kvcache(KVCache(self)) steps = 0 total_time = 0 prefill_time = 0 decode_time = 0 output_content = "" # Prefill phase - process initial prompt prefill_start_time = time.time() output_tokens = self.batch_infer_one_round([infer_task]) prefill_end_time = time.time() prefill_time = prefill_end_time - prefill_start_time steps += 1 output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: # If generation ends after prefill, calculate metrics total_time = prefill_time total_tokens = len(tokens) + 1 # input tokens + first output token print("\n") print(f"Time per step: {total_time * 1000:.3f}ms") if verbose: overall_throughput = total_tokens / total_time prefill_throughput = len(tokens) / prefill_time decode_throughput = 1 / 0.001 # Avoid division by zero, use small value print("=" * 50) print("PERFORMANCE METRICS") print("=" * 50) print(f"Input tokens: {len(tokens)}") print(f"Generated tokens: 1") print(f"Total tokens: {total_tokens}") print(f"Total time: {total_time * 1000:.3f}ms") print(f"Prefill time: {prefill_time * 1000:.3f}ms") print(f"Decode time: 0.000ms") print("-" * 50) print(f"Time per step: {total_time * 1000:.3f}ms") print( f"Avg prefill time per token: {prefill_time * 1000 / len(tokens):.3f}ms" ) print(f"Avg decode time per token: N/A") print("-" * 50) print(f"Overall throughput: {overall_throughput:.2f} tokens/s") print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s") print(f"Decode throughput: N/A") print("=" * 50) return output_content, total_time * 1000 infer_task.next(output_tokens[0]) # Decode phase - generate subsequent tokens decode_start_time = time.time() for step_i in range(1, max_steps): start_time = time.time() output_tokens = self.batch_infer_one_round([infer_task]) end_time = time.time() steps += 1 output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: break infer_task.next(output_tokens[0]) if step_i > 0: total_time += end_time - start_time decode_end_time = time.time() decode_time = decode_end_time - decode_start_time print("\n") # Calculate performance metrics total_time = prefill_time + decode_time input_tokens = len(tokens) generated_tokens = steps # including first token from prefill # Time per token calculations avg_time_per_step = ( total_time * 1000 / (steps - 1) if steps > 1 else total_time * 1000 ) print(f"Time per step: {avg_time_per_step:.3f}ms") # Only print detailed metrics if verbose flag is set if verbose: total_tokens = input_tokens + generated_tokens # Throughput calculations overall_throughput = total_tokens / total_time # tokens per second prefill_throughput = input_tokens / prefill_time if prefill_time > 0 else 0 decode_throughput = ( (generated_tokens - 1) / decode_time if decode_time > 0 else 0 ) # exclude first token from prefill # Time per token calculations avg_prefill_time_per_token = ( prefill_time * 1000 / input_tokens if input_tokens > 0 else 0 ) avg_decode_time_per_token = ( decode_time * 1000 / (generated_tokens - 1) if generated_tokens > 1 else 0 ) print("=" * 50) print("PERFORMANCE METRICS") print("=" * 50) print(f"Input tokens: {input_tokens}") print(f"Generated tokens: {generated_tokens}") print(f"Total tokens: {total_tokens}") print(f"Total time: {total_time * 1000:.3f}ms") print(f"Prefill time: {prefill_time * 1000:.3f}ms") print(f"Decode time: {decode_time * 1000:.3f}ms") print("-" * 50) print(f"Time per step: {avg_time_per_step:.3f}ms") print(f"Avg prefill time per token: {avg_prefill_time_per_token:.3f}ms") print(f"Avg decode time per token: {avg_decode_time_per_token:.3f}ms") print("-" * 50) print(f"Overall throughput: {overall_throughput:.2f} tokens/s") print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s") print(f"Decode throughput: {decode_throughput:.2f} tokens/s") print("=" * 50) infer_task._kv_cache.drop(self) return output_content, avg_time_per_step def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): tasks = [ InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) for i in range(batch_size) ] kv_caches = [KVCache(self) for _ in range(batch_size)] nll = 0.0 total_len = 0 for i in range(0, len(test_sequences), batch_size): batch_id = 0 true_tokens = [] while batch_id < batch_size and batch_id + i < len(test_sequences): input_tokens = test_sequences[i + batch_id][:-1] true_tokens.extend(test_sequences[i + batch_id][1:]) tasks[batch_id].tokens = input_tokens tasks[batch_id].bind_kvcache(kv_caches[batch_id]) batch_id += 1 batch_inputs = JiugeBatchedTask(tasks[:batch_id]) logits = torch.zeros( (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits ) self.jiuge_model.forward_batch( self.model_instance, batch_inputs.tokens, batch_inputs.ntok, batch_inputs.req_lens, batch_inputs.nreq, batch_inputs.req_pos, batch_inputs.kv_caches, logits.data_ptr(), ) logits = logits.float() token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) token_logprobs = log_probs[ torch.arange(batch_inputs.ntok), token_ids ] # (ntok,) start = 0 for l in batch_inputs.req_lens_list: nll += -token_logprobs[start : start + l].sum().item() start += l total_len += token_logprobs.numel() for task in tasks: task.release_kvcache() return math.exp(nll / total_len) def destroy_model_instance(self): self.jiuge_model.destroy_model(self.model_instance) print("Model destroyed") def test(): if len(sys.argv) < 3: print( "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" ) sys.exit(1) # Parse command line arguments model_path = sys.argv[2] device_type = DeviceType.DEVICE_TYPE_CPU verbose = False # Check for verbose flag for arg in sys.argv: if arg == "--verbose": verbose = True break if sys.argv[1] == "--cpu": device_type = DeviceType.DEVICE_TYPE_CPU elif sys.argv[1] == "--nvidia": device_type = DeviceType.DEVICE_TYPE_NVIDIA elif sys.argv[1] == "--cambricon": device_type = DeviceType.DEVICE_TYPE_CAMBRICON elif sys.argv[1] == "--ascend": device_type = DeviceType.DEVICE_TYPE_ASCEND elif sys.argv[1] == "--metax": device_type = DeviceType.DEVICE_TYPE_METAX elif sys.argv[1] == "--moore": device_type = DeviceType.DEVICE_TYPE_MOORE elif sys.argv[1] == "--iluvatar": device_type = DeviceType.DEVICE_TYPE_ILUVATAR elif sys.argv[1] == "--kunlun": device_type = DeviceType.DEVICE_TYPE_KUNLUN elif sys.argv[1] == "--hygon": device_type = DeviceType.DEVICE_TYPE_HYGON else: print( "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" ) sys.exit(1) # Find n_device argument (skip --verbose) ndev_args = [arg for arg in sys.argv[3:] if arg != "--verbose"] ndev = int(ndev_args[0]) if ndev_args else 1 model = JiugeForCauslLM(model_path, device_type, ndev) model.generate("山东最高的山是?", 500, verbose=verbose) model.destroy_model_instance() if __name__ == "__main__": test()