jiuge.py 31.3 KB
Newer Older
PanZezhong's avatar
PanZezhong committed
1
from typing import List, Sequence
2
3
4
5
6
7
8
9
10
import math
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import torch
import transformers
PanZezhong's avatar
PanZezhong committed
11

PanZezhong's avatar
PanZezhong committed
12
from libinfinicore_infer import (
13
    JiugeModel,
14
15
    JiugeMetaCStruct,
    JiugeWeightsCStruct,
PanZezhong's avatar
PanZezhong committed
16
17
    DataType,
    DeviceType,
18
    KVCacheCStruct,
PanZezhong's avatar
PanZezhong committed
19
)
20
from infer_task import InferTask, KVCache
Pan Zezhong's avatar
Pan Zezhong committed
21
22

from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
PanZezhong's avatar
PanZezhong committed
23

Pan Zezhong's avatar
Pan Zezhong committed
24
torch.set_default_device("cpu")
PanZezhong's avatar
PanZezhong committed
25

26

PanZezhong's avatar
PanZezhong committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class LlamaWeightsNaming:
    def input_embd(self):
        return "model.embed_tokens.weight"

    def output_norm(self):
        return "model.norm.weight"

    def output_embd(self):
        return "lm_head.weight"

    def attn_norm(self, i):
        return f"model.layers.{i}.input_layernorm.weight"

    def attn_q(self, i):
        return f"model.layers.{i}.self_attn.q_proj.weight"

    def attn_k(self, i):
        return f"model.layers.{i}.self_attn.k_proj.weight"

    def attn_v(self, i):
        return f"model.layers.{i}.self_attn.v_proj.weight"

    def attn_o(self, i):
        return f"model.layers.{i}.self_attn.o_proj.weight"

    def attn_q_b(self, i):
        return f"model.layers.{i}.self_attn.q_proj.bias"

    def attn_k_b(self, i):
        return f"model.layers.{i}.self_attn.k_proj.bias"

    def attn_v_b(self, i):
        return f"model.layers.{i}.self_attn.v_proj.bias"

61
62
63
64
65
66
    def attn_q_norm(self, i):
        return f"model.layers.{i}.self_attn.q_norm.weight"

    def attn_k_norm(self, i):
        return f"model.layers.{i}.self_attn.k_norm.weight"

PanZezhong's avatar
PanZezhong committed
67
68
69
    def ffn_norm(self, i):
        return f"model.layers.{i}.post_attention_layernorm.weight"

PanZezhong's avatar
PanZezhong committed
70
71
72
73
74
75
76
77
78
    def gate(self, i):
        return f"model.layers.{i}.mlp.gate_proj.weight"

    def up(self, i):
        return f"model.layers.{i}.mlp.up_proj.weight"

    def down(self, i):
        return f"model.layers.{i}.mlp.down_proj.weight"

PanZezhong's avatar
PanZezhong committed
79
80
81
82
83
84
    def match(state_dict):
        return (
            "model.norm.weight" in state_dict
            and "model.layers.0.self_attn.q_proj.weight" in state_dict
        )

PanZezhong's avatar
PanZezhong committed
85

86
class JiugeMetaFromLlama(JiugeMetaCStruct):
87
    def __init__(self, config, dtype=torch.float16, max_tokens=None):
Pan Zezhong's avatar
Pan Zezhong committed
88
        if dtype == torch.float16:
PanZezhong's avatar
PanZezhong committed
89
90
91
            dt_ = DataType.INFINI_DTYPE_F16
        elif dtype == torch.float32:
            dt_ = DataType.INFINI_DTYPE_F32
PanZezhong's avatar
PanZezhong committed
92
93
        elif dtype == torch.bfloat16:
            dt_ = DataType.INFINI_DTYPE_BF16
PanZezhong's avatar
PanZezhong committed
94
95
        else:
            dt_ = DataType.INFINI_DTYPE_F16
96

97
98
99
100
        self.scale_input = 1.0
        self.scale_output = 1.0
        self.scale_o = 1.0
        self.scale_down = 1.0
PanZezhong's avatar
PanZezhong committed
101
        if (
102
            config["model_type"] in ["fm9g", "minicpm"]
PanZezhong's avatar
PanZezhong committed
103
104
105
106
            and "scale_emb" in config
            and "scale_depth" in config
            and "dim_model_base" in config
        ):
107
108
109
110
111
112
113
114
            self.scale_input = config["scale_emb"]
            self.scale_output = config["hidden_size"] // config["dim_model_base"]
            self.scale_o = config["scale_depth"] / math.sqrt(
                config["num_hidden_layers"]
            )
            self.scale_down = config["scale_depth"] / math.sqrt(
                config["num_hidden_layers"]
            )
115

PanZezhong's avatar
PanZezhong committed
116
        super().__init__(
PanZezhong's avatar
PanZezhong committed
117
            dt_logits=dt_,
Pan Zezhong's avatar
Pan Zezhong committed
118
119
120
            nlayer=config["num_hidden_layers"],
            d=config["hidden_size"],
            nh=config["num_attention_heads"],
PanZezhong's avatar
PanZezhong committed
121
            nkvh=(
Pan Zezhong's avatar
Pan Zezhong committed
122
123
124
                config["num_key_value_heads"]
                if "num_key_value_heads" in config
                else config["num_attention_heads"]
PanZezhong's avatar
PanZezhong committed
125
            ),
wooway777's avatar
wooway777 committed
126
127
128
129
130
            dh=(
                config["head_dim"]
                if "head_dim" in config
                else config["hidden_size"] // config["num_attention_heads"]
            ),
Pan Zezhong's avatar
Pan Zezhong committed
131
            di=config["intermediate_size"],
132
133
134
            dctx=(
                config["max_position_embeddings"] if max_tokens is None else max_tokens
            ),
Pan Zezhong's avatar
Pan Zezhong committed
135
136
137
            dvoc=config["vocab_size"],
            epsilon=config["rms_norm_eps"],
            theta=(config["rope_theta"] if "rope_theta" in config else 100000.0),
PanZezhong's avatar
PanZezhong committed
138
139
            end_token=2,
        )
PanZezhong's avatar
PanZezhong committed
140
        self.torch_dtype_logits = dtype
PanZezhong's avatar
PanZezhong committed
141
142


143
class JiugeWeightsImpl(JiugeWeightsCStruct):
Pan Zezhong's avatar
Pan Zezhong committed
144
145
146
147
148
149
150
151
    def __init__(
        self,
        meta,
        naming,
        state_dict,
        torch_dt_mat=torch.float16,
        torch_dt_norm=torch.float32,
        ndev=1,
PanZezhong's avatar
PanZezhong committed
152
        transpose_weight=True,
Pan Zezhong's avatar
Pan Zezhong committed
153
    ):
PanZezhong's avatar
PanZezhong committed
154
155
156
157
158
159
        nlayer = meta.nlayer
        nh = meta.nh
        nkvh = meta.nkvh
        dh = meta.dh
        d = meta.d
        di = meta.di
160
161
162
163
        scale_input = meta.scale_input
        scale_output = meta.scale_output
        scale_o = meta.scale_o
        scale_down = meta.scale_down
PanZezhong's avatar
PanZezhong committed
164
165
166
167
        assert nh % nkvh == 0
        assert nh % ndev == 0
        assert nkvh % ndev == 0
        assert di % ndev == 0
PanZezhong's avatar
PanZezhong committed
168
169
170
171
172
        torch_dt_logits = meta.torch_dtype_logits
        if torch_dt_mat == torch.float16:
            self.dt_mat = DataType.INFINI_DTYPE_F16
        elif torch_dt_mat == torch.float32:
            self.dt_mat = DataType.INFINI_DTYPE_F32
PanZezhong's avatar
PanZezhong committed
173
174
        elif torch_dt_mat == torch.bfloat16:
            self.dt_mat = DataType.INFINI_DTYPE_BF16
PanZezhong's avatar
PanZezhong committed
175
176
177
178
179
180
        else:
            raise ValueError("Unsupported proj weight data type")
        if torch_dt_norm == torch.float16:
            self.dt_norm = DataType.INFINI_DTYPE_F16
        elif torch_dt_norm == torch.float32:
            self.dt_norm = DataType.INFINI_DTYPE_F32
PanZezhong's avatar
PanZezhong committed
181
182
        elif torch_dt_norm == torch.bfloat16:
            self.dt_norm = DataType.INFINI_DTYPE_BF16
PanZezhong's avatar
PanZezhong committed
183
184
        else:
            raise ValueError("Unsupported norm weight data type")
PanZezhong's avatar
PanZezhong committed
185

Pan Zezhong's avatar
Pan Zezhong committed
186
187
188
189
190
191
192
193
194
195
        input_embd_naming = (
            naming.input_embd()
            if naming.input_embd() in state_dict
            else naming.output_embd()
        )
        output_embd_naming = (
            naming.output_embd()
            if naming.output_embd() in state_dict
            else naming.input_embd()
        )
PanZezhong's avatar
PanZezhong committed
196
        self.transpose_linear_weights = 1 if transpose_weight else 0
PanZezhong's avatar
PanZezhong committed
197
        self.nlayer = nlayer
198
199
200
        self.input_embd_tensor = (
            state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
        )
PanZezhong's avatar
PanZezhong committed
201
        self.input_embd = self.input_embd_tensor.data_ptr()
202
203
204
        self.output_norm_tensor = (
            state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
        )
PanZezhong's avatar
PanZezhong committed
205
        self.output_norm = self.output_norm_tensor.data_ptr()
Pan Zezhong's avatar
Pan Zezhong committed
206
        self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
PanZezhong's avatar
PanZezhong committed
207
        if not transpose_weight:
208
209
210
            self.output_embd_tensor = self.output_embd_tensor.transpose(
                0, 1
            ).contiguous()
PanZezhong's avatar
PanZezhong committed
211
212
213
        self.output_embd = self.output_embd_tensor.data_ptr()

        self.attn_norm_tensors = [
PanZezhong's avatar
PanZezhong committed
214
            state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
PanZezhong's avatar
PanZezhong committed
215
216
217
218
219
        ]
        self.attn_norm_ptrs = [
            self.attn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs)
PanZezhong's avatar
PanZezhong committed
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241

        def qkv_slices(_i):
            _Q = (
                state_dict[naming.attn_q(_i)]
                .reshape([nh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _K = (
                state_dict[naming.attn_k(_i)]
                .reshape([nkvh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
                _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :])
                _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :])
                _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            return _result

Pan Zezhong's avatar
Pan Zezhong committed
242
243
244
        self.qkv_tensor = [
            torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
PanZezhong's avatar
PanZezhong committed
245
246
        if not transpose_weight:
            for i in range(nlayer):
247
248
249
250
251
252
                self.qkv_tensor[i] = (
                    self.qkv_tensor[i]
                    .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
                    .transpose(1, 2)
                    .contiguous()
                )
PanZezhong's avatar
PanZezhong committed
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
        self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)

        def qkv_b_slices(_i):
            _QB = (
                state_dict[naming.attn_q_b(_i)]
                .reshape([nh, 2, dh // 2])
                .transpose(1, 2)
            )
            _KB = (
                state_dict[naming.attn_k_b(_i)]
                .reshape([nkvh, 2, dh // 2])
                .transpose(1, 2)
            )
            _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
Pan Zezhong's avatar
Pan Zezhong committed
272
273
274
                _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten())
                _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
                _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
PanZezhong's avatar
PanZezhong committed
275
276
277
            return _result

        if naming.attn_q_b(0) in state_dict:
Pan Zezhong's avatar
Pan Zezhong committed
278
279
280
            self.qkv_b_tensors = [
                torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)
            ]
PanZezhong's avatar
PanZezhong committed
281
282
283
284
285
286
287
            self.qkv_b_tensor_ptrs = [
                self.qkv_b_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs)
        else:
            self.attn_qkv_b = None

288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
        if naming.attn_q_norm(0) in state_dict:
            self.attn_q_norm_tensors = [
                state_dict[naming.attn_q_norm(i)]
                .reshape([2, dh // 2])
                .transpose(0, 1)
                .contiguous()
                .to(torch_dt_norm)
                for i in range(nlayer)
            ]
            self.attn_q_norm_ptrs = [
                self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs)
            self.attn_k_norm_tensors = [
                state_dict[naming.attn_k_norm(i)]
                .reshape([2, dh // 2])
                .transpose(0, 1)
                .contiguous()
                .to(torch_dt_norm)
                for i in range(nlayer)
            ]
            self.attn_k_norm_ptrs = [
                self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs)
        else:
            self.attn_q_norm = None
            self.attn_k_norm = None

PanZezhong's avatar
PanZezhong committed
317
        self.attn_o_tensor = [
318
319
            (
                state_dict[naming.attn_o(i)]
PanZezhong's avatar
PanZezhong committed
320
321
322
323
                .to(torch_dt_mat)
                .reshape([d, ndev, nh // ndev * dh])
                .transpose(0, 1)
                .contiguous()
324
325
326
327
328
329
                if transpose_weight
                else state_dict[naming.attn_o(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
330
            * scale_o
PanZezhong's avatar
PanZezhong committed
331
332
            for i in range(nlayer)
        ]
PanZezhong's avatar
PanZezhong committed
333
334
335
        self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)

Pan Zezhong's avatar
Pan Zezhong committed
336
337
338
        self.ffn_norm_tensors = [
            state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
PanZezhong's avatar
PanZezhong committed
339
340
341
342
        self.ffn_norm_ptrs = [
            self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs)
PanZezhong's avatar
PanZezhong committed
343
344
345
346
347
348
349
350
351
352
353

        def gate_up_slices(_i):
            _result = []
            _di = di // ndev
            for _idev in range(ndev):
                _start = _idev * _di
                _end = (_idev + 1) * _di
                _result.append(state_dict[naming.gate(_i)][_start:_end, :])
                _result.append(state_dict[naming.up(_i)][_start:_end, :])
            return _result

Pan Zezhong's avatar
Pan Zezhong committed
354
355
356
        self.gate_up_tensors = [
            torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
PanZezhong's avatar
PanZezhong committed
357
358
        if not transpose_weight:
            for i in range(nlayer):
359
360
361
362
363
364
                self.gate_up_tensors[i] = (
                    self.gate_up_tensors[i]
                    .reshape(ndev, 2 * di // ndev, d)
                    .transpose(1, 2)
                    .contiguous()
                )
PanZezhong's avatar
PanZezhong committed
365
366
        self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
        self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)
PanZezhong's avatar
PanZezhong committed
367
368

        self.ffn_down_tensor = [
369
370
371
372
373
374
375
376
377
378
379
380
            (
                state_dict[naming.down(i)]
                .to(torch_dt_mat)
                .reshape([d, ndev, di // ndev])
                .transpose(0, 1)
                .contiguous()
                if transpose_weight
                else state_dict[naming.down(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
381
            * scale_down
PanZezhong's avatar
PanZezhong committed
382
383
            for i in range(nlayer)
        ]
PanZezhong's avatar
PanZezhong committed
384
385
        self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
        self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)
PanZezhong's avatar
PanZezhong committed
386
387


Pan Zezhong's avatar
Pan Zezhong committed
388
389
390
391
392
393
394
395
396
class JiugeBatchedTask:
    def __init__(self, tasks: List[InferTask]):
        self.tasks = tasks
        self.nreq = len(tasks)

        # Precompute fields
        token_lists = [t.tokens for t in tasks]
        self.req_lens_list = [len(toks) for toks in token_lists]
        self.req_pos_list = [t.pos for t in tasks]
397
        self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
Pan Zezhong's avatar
Pan Zezhong committed
398
399
400
401
402
403
404
405
406
407
408
409
        self.temperaturas_list = [t.temperature for t in tasks]
        self.topks_list = [t.topk for t in tasks]
        self.topps_list = [t.topp for t in tasks]

        # Flatten token lists
        flat_tokens = [tok for toks in token_lists for tok in toks]
        self.ntok = len(flat_tokens)

        # Convert to ctypes arrays in one pass
        self.tokens = (c_uint * self.ntok)(*flat_tokens)
        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
410
        self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
Pan Zezhong's avatar
Pan Zezhong committed
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
        self.topks = (c_uint * self.nreq)(*self.topks_list)
        self.topps = (c_float * self.nreq)(*self.topps_list)

    def input_args(self):
        return (
            self.tokens,
            self.ntok,
            self.req_lens,
            self.nreq,
            self.req_pos,
            self.kv_caches,
            self.temperaturas,
            self.topks,
            self.topps,
        )


PanZezhong's avatar
PanZezhong committed
429
class JiugeForCauslLM:
430
431
432
    def __init__(
        self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
    ):
PanZezhong's avatar
PanZezhong committed
433
        def load_all_safetensors_from_dir(dir_path_: str):
PanZezhong's avatar
PanZezhong committed
434
435
436
437
438
            tensors_ = {}
            dir_path_ = Path(dir_path_)
            for file in sorted(dir_path_.glob("*.safetensors")):
                data_ = safetensors.safe_open(file, "pt")
                for name_ in data_.keys():
PanZezhong's avatar
PanZezhong committed
439
                    tensors_[name_] = data_.get_tensor(name_)
PanZezhong's avatar
PanZezhong committed
440
            return tensors_
441

PanZezhong's avatar
PanZezhong committed
442
443
        print("Loading model weights to host...")
        load_start_time = time.time()
PanZezhong's avatar
PanZezhong committed
444

Pan Zezhong's avatar
Pan Zezhong committed
445
446
        with open(os.path.join(model_dir_path, "config.json"), "r") as f:
            config = json.load(f)
PanZezhong's avatar
PanZezhong committed
447
448
            self.config = config
        eos_token_id = self.config["eos_token_id"]
449
450
451
452
453
454
        self.eos_token_id = (
            [eos_token_id] if type(eos_token_id) == int else eos_token_id
        )
        transpose_weight = (
            device != DeviceType.DEVICE_TYPE_ASCEND
        )  # y = xW is faster than y=xW^T on Ascend
455
456
457

        self.jiuge_model = JiugeModel()

Pan Zezhong's avatar
Pan Zezhong committed
458
        if "llama" == config["model_type"]:
459
460
461
462
463
            model = (
                transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
                .cpu()
                .half()
            )
464
            self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
PanZezhong's avatar
PanZezhong committed
465
466
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
            self.weights = JiugeWeightsImpl(
467
468
469
470
471
                self.meta,
                LlamaWeightsNaming(),
                model.state_dict(),
                ndev=ndev,
                transpose_weight=transpose_weight,
PanZezhong's avatar
PanZezhong committed
472
            )
473
        elif "fm9g" == config["model_type"] or "minicpm" == config["model_type"]:
474
475
476
            if any(
                file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
            ):
PanZezhong's avatar
PanZezhong committed
477
478
479
                state_dict = load_all_safetensors_from_dir(model_dir_path)
            else:
                state_dict = torch.load(
480
481
482
483
                    os.path.join(model_dir_path, "pytorch_model.bin"),
                    weights_only=True,
                    map_location="cpu",
                )
PanZezhong's avatar
PanZezhong committed
484
            if LlamaWeightsNaming.match(state_dict):
485
                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
PanZezhong's avatar
PanZezhong committed
486
                self.weights = JiugeWeightsImpl(
487
488
489
490
491
                    self.meta,
                    LlamaWeightsNaming(),
                    state_dict,
                    ndev=ndev,
                    transpose_weight=transpose_weight,
PanZezhong's avatar
PanZezhong committed
492
493
494
495
                )
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_dir_path, trust_remote_code=True
                )
PanZezhong's avatar
PanZezhong committed
496
497
            else:
                raise ValueError("Unsupported weight naming")
Pan Zezhong's avatar
Pan Zezhong committed
498
        elif "fm9g7b" == config["model_type"]:
Pan Zezhong's avatar
Pan Zezhong committed
499
500
501
502
503
504
505
506
507
508
            if any(
                file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
            ):
                state_dict = load_all_safetensors_from_dir(model_dir_path)
            else:
                state_dict = torch.load(
                    os.path.join(model_dir_path, "pytorch_model.bin"),
                    weights_only=True,
                    map_location="cpu",
                )
Pan Zezhong's avatar
Pan Zezhong committed
509
            if LlamaWeightsNaming.match(state_dict):
510
                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
Pan Zezhong's avatar
Pan Zezhong committed
511
                self.weights = JiugeWeightsImpl(
512
513
514
515
516
                    self.meta,
                    LlamaWeightsNaming(),
                    state_dict,
                    ndev=ndev,
                    transpose_weight=transpose_weight,
Pan Zezhong's avatar
Pan Zezhong committed
517
518
519
520
                )
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_dir_path, trust_remote_code=True
                )
PanZezhong's avatar
PanZezhong committed
521
522
            else:
                raise ValueError("Unsupported weight naming")
523
        elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]:
PanZezhong's avatar
PanZezhong committed
524
525
            state_dict = load_all_safetensors_from_dir(model_dir_path)
            if LlamaWeightsNaming.match(state_dict):
526
                self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
PanZezhong's avatar
PanZezhong committed
527
                self.weights = JiugeWeightsImpl(
528
529
530
531
532
                    self.meta,
                    LlamaWeightsNaming(),
                    state_dict,
                    ndev=ndev,
                    transpose_weight=transpose_weight,
PanZezhong's avatar
PanZezhong committed
533
534
535
536
                )
                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                    model_dir_path
                )
PanZezhong's avatar
PanZezhong committed
537
538
        else:
            raise ValueError("Unsupported model architecture")
PanZezhong's avatar
PanZezhong committed
539

540
541
        if "llama" == config["model_type"]:
            from tokenizers import decoders as _dec
wooway777's avatar
wooway777 committed
542

543
544
545
546
547
548
549
550
551
            backend = getattr(self.tokenizer, "backend_tokenizer", None)
            target = getattr(backend, "_tokenizer", backend)
            norm = getattr(target, "normalizer", None)
            dec = getattr(target, "decoder", None)
            sn = repr(norm)[:800] if norm is not None else ""
            sd = repr(dec)[:800] if dec is not None else ""
            has_prepend = "Prepend" in sn
            has_strip = "Strip" in sd
            if has_prepend and has_strip:
wooway777's avatar
wooway777 committed
552
553
554
555
556
557
558
                target.decoder = _dec.Sequence(
                    [
                        _dec.Replace("▁", " "),
                        _dec.ByteFallback(),
                        _dec.Fuse(),
                    ]
                )
559

PanZezhong's avatar
PanZezhong committed
560
561
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")
562

PanZezhong's avatar
PanZezhong committed
563
564
        print(f"Creating model on {ndev} devices...")
        load_start_time = time.time()
blkmjsian's avatar
blkmjsian committed
565
566
567
        self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)])
        self.ndev = ndev
        self.device = device
568
569

        self.model_instance = self.jiuge_model.create_model(
PanZezhong's avatar
PanZezhong committed
570
571
572
573
            byref(self.meta),
            byref(self.weights),
            device,
            ndev,
blkmjsian's avatar
blkmjsian committed
574
            self.dev_ids,
PanZezhong's avatar
PanZezhong committed
575
        )
PanZezhong's avatar
PanZezhong committed
576
577
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")
578

Pan Zezhong's avatar
Pan Zezhong committed
579
580
    def max_context_len(self):
        return self.meta.dctx
581

Pan Zezhong's avatar
Pan Zezhong committed
582
    def create_kv_cache(self):
583
        return self.jiuge_model.create_kv_cache(
blkmjsian's avatar
blkmjsian committed
584
585
586
587
588
589
590
591
592
593
            self.meta.nlayer,
            self.meta.dctx,
            self.meta.nkvh,
            self.meta.dh,
            self.meta.dh,
            self.meta.dt_logits,
            self.device,
            self.dev_ids,
            self.ndev,
        )
Pan Zezhong's avatar
Pan Zezhong committed
594
595

    def drop_kv_cache(self, kv_cache):
596
        self.jiuge_model.drop_kv_cache(kv_cache)
Pan Zezhong's avatar
Pan Zezhong committed
597

Pan Zezhong's avatar
Pan Zezhong committed
598
599
600
    def batch_infer_one_round(self, tasks: List[InferTask]):
        output = (c_uint * len(tasks))()
        batch_inputs = JiugeBatchedTask(tasks)
601
        self.jiuge_model.infer_batch(
Pan Zezhong's avatar
Pan Zezhong committed
602
603
604
            self.model_instance,
            *(batch_inputs.input_args()),
            output,
Pan Zezhong's avatar
Pan Zezhong committed
605
        )
Pan Zezhong's avatar
Pan Zezhong committed
606
        return list(output)
Pan Zezhong's avatar
Pan Zezhong committed
607

wooway777's avatar
wooway777 committed
608
609
610
611
612
613
614
615
616
    def generate(
        self,
        input_content,
        max_steps,
        topp_=1.0,
        topk_=1,
        temperature_=1.0,
        verbose=False,
    ):
Pan Zezhong's avatar
Pan Zezhong committed
617
618
619
620
621
        input_content = self.tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": input_content}],
            add_generation_prompt=True,
            tokenize=False,
        )
PanZezhong's avatar
PanZezhong committed
622
623
        print(input_content, end="", flush=True)
        tokens = self.tokenizer.encode(input_content)
624
625
626
627
628
629
630
631
632
633
        infer_task = InferTask(
            0,
            tokens,
            self.max_context_len(),
            temperature_,
            topk_,
            topp_,
            self.eos_token_id,
        )
        infer_task.bind_kvcache(KVCache(self))
PanZezhong's avatar
PanZezhong committed
634
635

        steps = 0
Pan Zezhong's avatar
Pan Zezhong committed
636
        total_time = 0
wooway777's avatar
wooway777 committed
637
638
        prefill_time = 0
        decode_time = 0
639
        output_content = ""
Pan Zezhong's avatar
Pan Zezhong committed
640

wooway777's avatar
wooway777 committed
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
        # Prefill phase - process initial prompt
        prefill_start_time = time.time()
        output_tokens = self.batch_infer_one_round([infer_task])
        prefill_end_time = time.time()
        prefill_time = prefill_end_time - prefill_start_time
        steps += 1

        output_str = self.tokenizer.decode(output_tokens[0])
        output_content += output_str
        print(output_str, end="", flush=True)
        if output_tokens[0] in self.eos_token_id:
            # If generation ends after prefill, calculate metrics
            total_time = prefill_time
            total_tokens = len(tokens) + 1  # input tokens + first output token

            print("\n")
            print(f"Time per step: {total_time * 1000:.3f}ms")

            if verbose:
                overall_throughput = total_tokens / total_time
                prefill_throughput = len(tokens) / prefill_time
                decode_throughput = 1 / 0.001  # Avoid division by zero, use small value

                print("=" * 50)
                print("PERFORMANCE METRICS")
                print("=" * 50)
                print(f"Input tokens: {len(tokens)}")
                print(f"Generated tokens: 1")
                print(f"Total tokens: {total_tokens}")
                print(f"Total time: {total_time * 1000:.3f}ms")
                print(f"Prefill time: {prefill_time * 1000:.3f}ms")
                print(f"Decode time: 0.000ms")
                print("-" * 50)
                print(f"Time per step: {total_time * 1000:.3f}ms")
                print(
                    f"Avg prefill time per token: {prefill_time * 1000 / len(tokens):.3f}ms"
                )
                print(f"Avg decode time per token: N/A")
                print("-" * 50)
                print(f"Overall throughput: {overall_throughput:.2f} tokens/s")
                print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s")
                print(f"Decode throughput: N/A")
                print("=" * 50)

            return output_content, total_time * 1000

        infer_task.next(output_tokens[0])

        # Decode phase - generate subsequent tokens
        decode_start_time = time.time()
        for step_i in range(1, max_steps):
Pan Zezhong's avatar
Pan Zezhong committed
692
            start_time = time.time()
693
            output_tokens = self.batch_infer_one_round([infer_task])
PanZezhong's avatar
PanZezhong committed
694
            end_time = time.time()
695
            steps += 1
696
697
            output_str = self.tokenizer.decode(output_tokens[0])

PanZezhong's avatar
PanZezhong committed
698
699
            output_content += output_str
            print(output_str, end="", flush=True)
PanZezhong's avatar
PanZezhong committed
700
701
            if output_tokens[0] in self.eos_token_id:
                break
702
            infer_task.next(output_tokens[0])
Pan Zezhong's avatar
Pan Zezhong committed
703

Pan Zezhong's avatar
Pan Zezhong committed
704
705
            if step_i > 0:
                total_time += end_time - start_time
PanZezhong's avatar
PanZezhong committed
706

wooway777's avatar
wooway777 committed
707
708
709
        decode_end_time = time.time()
        decode_time = decode_end_time - decode_start_time

PanZezhong's avatar
PanZezhong committed
710
        print("\n")
wooway777's avatar
wooway777 committed
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762

        # Calculate performance metrics
        total_time = prefill_time + decode_time
        input_tokens = len(tokens)
        generated_tokens = steps  # including first token from prefill

        # Time per token calculations
        avg_time_per_step = (
            total_time * 1000 / (steps - 1) if steps > 1 else total_time * 1000
        )

        print(f"Time per step: {avg_time_per_step:.3f}ms")

        # Only print detailed metrics if verbose flag is set
        if verbose:
            total_tokens = input_tokens + generated_tokens

            # Throughput calculations
            overall_throughput = total_tokens / total_time  # tokens per second
            prefill_throughput = input_tokens / prefill_time if prefill_time > 0 else 0
            decode_throughput = (
                (generated_tokens - 1) / decode_time if decode_time > 0 else 0
            )  # exclude first token from prefill

            # Time per token calculations
            avg_prefill_time_per_token = (
                prefill_time * 1000 / input_tokens if input_tokens > 0 else 0
            )
            avg_decode_time_per_token = (
                decode_time * 1000 / (generated_tokens - 1)
                if generated_tokens > 1
                else 0
            )

            print("=" * 50)
            print("PERFORMANCE METRICS")
            print("=" * 50)
            print(f"Input tokens: {input_tokens}")
            print(f"Generated tokens: {generated_tokens}")
            print(f"Total tokens: {total_tokens}")
            print(f"Total time: {total_time * 1000:.3f}ms")
            print(f"Prefill time: {prefill_time * 1000:.3f}ms")
            print(f"Decode time: {decode_time * 1000:.3f}ms")
            print("-" * 50)
            print(f"Time per step: {avg_time_per_step:.3f}ms")
            print(f"Avg prefill time per token: {avg_prefill_time_per_token:.3f}ms")
            print(f"Avg decode time per token: {avg_decode_time_per_token:.3f}ms")
            print("-" * 50)
            print(f"Overall throughput: {overall_throughput:.2f} tokens/s")
            print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s")
            print(f"Decode throughput: {decode_throughput:.2f} tokens/s")
            print("=" * 50)
763
764

        infer_task._kv_cache.drop(self)
wooway777's avatar
wooway777 committed
765
        return output_content, avg_time_per_step
Pan Zezhong's avatar
Pan Zezhong committed
766

PanZezhong's avatar
PanZezhong committed
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
    def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10):
        tasks = [
            InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id)
            for i in range(batch_size)
        ]
        kv_caches = [KVCache(self) for _ in range(batch_size)]

        nll = 0.0
        total_len = 0

        for i in range(0, len(test_sequences), batch_size):
            batch_id = 0
            true_tokens = []
            while batch_id < batch_size and batch_id + i < len(test_sequences):
                input_tokens = test_sequences[i + batch_id][:-1]
                true_tokens.extend(test_sequences[i + batch_id][1:])
                tasks[batch_id].tokens = input_tokens
                tasks[batch_id].bind_kvcache(kv_caches[batch_id])
                batch_id += 1

            batch_inputs = JiugeBatchedTask(tasks[:batch_id])
            logits = torch.zeros(
                (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits
            )
791
            self.jiuge_model.forward_batch(
PanZezhong's avatar
PanZezhong committed
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
                self.model_instance,
                batch_inputs.tokens,
                batch_inputs.ntok,
                batch_inputs.req_lens,
                batch_inputs.nreq,
                batch_inputs.req_pos,
                batch_inputs.kv_caches,
                logits.data_ptr(),
            )

            logits = logits.float()
            token_ids = torch.tensor(true_tokens, dtype=torch.int64)  # [ntok,]
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1)  # (ntok, vocab)
            token_logprobs = log_probs[
                torch.arange(batch_inputs.ntok), token_ids
            ]  # (ntok,)

            start = 0
            for l in batch_inputs.req_lens_list:
                nll += -token_logprobs[start : start + l].sum().item()
                start += l
            total_len += token_logprobs.numel()

        for task in tasks:
            task.release_kvcache()

        return math.exp(nll / total_len)

PanZezhong's avatar
PanZezhong committed
820
    def destroy_model_instance(self):
821
        self.jiuge_model.destroy_model(self.model_instance)
PanZezhong's avatar
PanZezhong committed
822
        print("Model destroyed")
PanZezhong's avatar
PanZezhong committed
823
824
825
826
827


def test():
    if len(sys.argv) < 3:
        print(
xgqdut2016's avatar
xgqdut2016 committed
828
            "Usage: python jiuge.py [--cpu | --nvidia| --qy| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> [n_device] [--verbose]"
PanZezhong's avatar
PanZezhong committed
829
830
        )
        sys.exit(1)
wooway777's avatar
wooway777 committed
831
832

    # Parse command line arguments
PanZezhong's avatar
PanZezhong committed
833
834
    model_path = sys.argv[2]
    device_type = DeviceType.DEVICE_TYPE_CPU
wooway777's avatar
wooway777 committed
835
836
837
838
839
840
841
842
    verbose = False

    # Check for verbose flag
    for arg in sys.argv:
        if arg == "--verbose":
            verbose = True
            break

PanZezhong's avatar
PanZezhong committed
843
844
845
846
    if sys.argv[1] == "--cpu":
        device_type = DeviceType.DEVICE_TYPE_CPU
    elif sys.argv[1] == "--nvidia":
        device_type = DeviceType.DEVICE_TYPE_NVIDIA
xgqdut2016's avatar
xgqdut2016 committed
847
848
    elif sys.argv[1] == "--qy":
        device_type = DeviceType.DEVICE_TYPE_QY
PanZezhong's avatar
PanZezhong committed
849
850
851
852
853
854
855
856
    elif sys.argv[1] == "--cambricon":
        device_type = DeviceType.DEVICE_TYPE_CAMBRICON
    elif sys.argv[1] == "--ascend":
        device_type = DeviceType.DEVICE_TYPE_ASCEND
    elif sys.argv[1] == "--metax":
        device_type = DeviceType.DEVICE_TYPE_METAX
    elif sys.argv[1] == "--moore":
        device_type = DeviceType.DEVICE_TYPE_MOORE
zhangyue's avatar
zhangyue committed
857
858
    elif sys.argv[1] == "--iluvatar":
        device_type = DeviceType.DEVICE_TYPE_ILUVATAR
zhangyue's avatar
zhangyue committed
859
860
    elif sys.argv[1] == "--kunlun":
        device_type = DeviceType.DEVICE_TYPE_KUNLUN
zhuyue's avatar
zhuyue committed
861
862
    elif sys.argv[1] == "--hygon":
        device_type = DeviceType.DEVICE_TYPE_HYGON
wooway777's avatar
wooway777 committed
863
864
    elif sys.argv[1] == "--ali":
        device_type = DeviceType.DEVICE_TYPE_ALI
PanZezhong's avatar
PanZezhong committed
865
866
    else:
        print(
wooway777's avatar
wooway777 committed
867
            "Usage: python jiuge.py [--cpu | --nvidia| --qy| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon | --ali] <path/to/model_dir> [n_device] [--verbose]"
PanZezhong's avatar
PanZezhong committed
868
869
870
        )
        sys.exit(1)

wooway777's avatar
wooway777 committed
871
872
873
874
    # Find n_device argument (skip --verbose)
    ndev_args = [arg for arg in sys.argv[3:] if arg != "--verbose"]
    ndev = int(ndev_args[0]) if ndev_args else 1

PanZezhong's avatar
PanZezhong committed
875
    model = JiugeForCauslLM(model_path, device_type, ndev)
wooway777's avatar
wooway777 committed
876
    model.generate("山东最高的山是?", 500, verbose=verbose)
PanZezhong's avatar
PanZezhong committed
877
    model.destroy_model_instance()
PanZezhong's avatar
PanZezhong committed
878
879
880
881


if __name__ == "__main__":
    test()