"""Python driver for the Jiuge model on top of ``libinfinicore_infer``.

Loads a checkpoint on the host, packs the weights into the C structures the
native library expects, and exposes a small generation API.
"""

from typing import List
from libinfinicore_infer import (
    JiugeMetaCStruct,
    JiugeWeightsCStruct,
    KVCacheCStruct,
    DataType,
    DeviceType,
    create_jiuge_model,
    destroy_jiuge_model,
    create_kv_cache,
    drop_kv_cache,
    infer_batch,
)
from infer_task import InferTask, KVCache

from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref
import os
from pathlib import Path
import safetensors
import sys
import time
import json
import math
import torch
import transformers

# Make the CPU the default device for every tensor torch creates in this module.
torch.set_default_device("cpu")


class LlamaWeightsNaming:
    """Names of checkpoint tensors laid out the way Llama state dicts are.

    Every accessor returns the ``state_dict`` key of one tensor; the per-layer
    accessors take the layer index ``i``.
    """

    def input_embd(self):
        """Return the key of the token-embedding matrix."""
        return "model.embed_tokens.weight"

    def output_norm(self):
        """Return the key of the final normalization weight."""
        return "model.norm.weight"

    def output_embd(self):
        """Return the key of the output (LM head) matrix."""
        return "lm_head.weight"

    def attn_norm(self, i):
        """Return the key of layer ``i``'s pre-attention norm weight."""
        return f"model.layers.{i}.input_layernorm.weight"

    def attn_q(self, i):
        """Return the key of layer ``i``'s query projection weight."""
        return f"model.layers.{i}.self_attn.q_proj.weight"

    def attn_k(self, i):
        """Return the key of layer ``i``'s key projection weight."""
        return f"model.layers.{i}.self_attn.k_proj.weight"

    def attn_v(self, i):
        """Return the key of layer ``i``'s value projection weight."""
        return f"model.layers.{i}.self_attn.v_proj.weight"

    def attn_o(self, i):
        """Return the key of layer ``i``'s attention output projection weight."""
        return f"model.layers.{i}.self_attn.o_proj.weight"

    def attn_q_b(self, i):
        """Return the key of layer ``i``'s query projection bias."""
        return f"model.layers.{i}.self_attn.q_proj.bias"

    def attn_k_b(self, i):
        """Return the key of layer ``i``'s key projection bias."""
        return f"model.layers.{i}.self_attn.k_proj.bias"

    def attn_v_b(self, i):
        """Return the key of layer ``i``'s value projection bias."""
        return f"model.layers.{i}.self_attn.v_proj.bias"

    def ffn_norm(self, i):
        """Return the key of layer ``i``'s post-attention norm weight."""
        return f"model.layers.{i}.post_attention_layernorm.weight"

    def gate(self, i):
        """Return the key of layer ``i``'s MLP gate projection weight."""
        return f"model.layers.{i}.mlp.gate_proj.weight"

    def up(self, i):
        """Return the key of layer ``i``'s MLP up projection weight."""
        return f"model.layers.{i}.mlp.up_proj.weight"

    def down(self, i):
        """Return the key of layer ``i``'s MLP down projection weight."""
        return f"model.layers.{i}.mlp.down_proj.weight"

    # Declared static so the check works both as ``LlamaWeightsNaming.match(sd)``
    # and on an instance; without the decorator an instance call would pass the
    # instance itself as ``state_dict``.
    @staticmethod
    def match(state_dict):
        """Return True when ``state_dict`` contains the marker keys of this naming."""
        return (
            "model.norm.weight" in state_dict
            and "model.layers.0.self_attn.q_proj.weight" in state_dict
        )


class JiugeMetaFromLlama(JiugeMetaCStruct):
    """Model meta-data structure filled in from a Llama-style ``config.json``.

    Besides the fields passed to the base structure, the instance carries
    ``torch_dtype_logits`` and the four ``scale_*`` factors that
    ``JiugeWeightsImpl`` reads when it packs the weights.
    """

    def __init__(self, config, dtype=torch.float16, max_tokens=None):
        """Build the meta-data from ``config``.

        Args:
            config: Parsed model configuration (a mapping). Keys read here:
                ``model_type``, ``num_hidden_layers``, ``hidden_size``,
                ``num_attention_heads``, ``intermediate_size``,
                ``max_position_embeddings``, ``vocab_size``, ``rms_norm_eps``
                and, optionally, ``num_key_value_heads`` and ``rope_theta``.
                For ``model_type == "fm9g"`` also ``scale_emb``,
                ``dim_model_base`` and ``scale_depth``.
            dtype: Torch dtype of the logits; float16, float32 and bfloat16
                are mapped directly, anything else falls back to float16.
            max_tokens: Context length override; ``None`` uses
                ``max_position_embeddings`` from ``config``.
        """
        if dtype == torch.float16:
            dt_ = DataType.INFINI_DTYPE_F16
        elif dtype == torch.float32:
            dt_ = DataType.INFINI_DTYPE_F32
        elif dtype == torch.bfloat16:
            dt_ = DataType.INFINI_DTYPE_BF16
        else:
            # Unsupported dtype: fall back to float16 on BOTH sides, so the
            # dtype recorded for host tensors agrees with dt_logits.
            dtype = torch.float16
            dt_ = DataType.INFINI_DTYPE_F16

        scale_input = 1.0
        scale_output = 1.0
        scale_o = 1.0
        scale_down = 1.0
        if "fm9g" == config["model_type"]:
            scale_input = config["scale_emb"]
            scale_output = config["hidden_size"] // config["dim_model_base"]
            scale_o = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])
            scale_down = config["scale_depth"] / math.sqrt(config["num_hidden_layers"])

        super().__init__(
            dt_logits=dt_,
            nlayer=config["num_hidden_layers"],
            d=config["hidden_size"],
            nh=config["num_attention_heads"],
            # Without an explicit KV head count, use one KV head per query head.
            nkvh=config.get("num_key_value_heads", config["num_attention_heads"]),
            dh=config["hidden_size"] // config["num_attention_heads"],
            di=config["intermediate_size"],
            dctx=(
                config["max_position_embeddings"] if max_tokens is None else max_tokens
            ),
            dvoc=config["vocab_size"],
            epsilon=config["rms_norm_eps"],
            # NOTE(review): the fallback is 100000.0, whereas Hugging Face's
            # Llama config defaults rope_theta to 10000.0 — confirm intended.
            theta=config.get("rope_theta", 100000.0),
            end_token=2,
        )
        self.torch_dtype_logits = dtype
        # Keep the scale factors on the instance: JiugeWeightsImpl reads them
        # as meta.scale_input / scale_output / scale_o / scale_down.
        self.scale_input = scale_input
        self.scale_output = scale_output
        self.scale_o = scale_o
        self.scale_down = scale_down


class JiugeWeightsImpl(JiugeWeightsCStruct):
    """Weights structure whose pointer fields refer to host tensors.

    Every tensor whose ``data_ptr()`` is handed to the structure is also kept
    as an attribute of the instance, so the tensors stay referenced for as
    long as this object exists.
    """

    def __init__(
        self,
        meta,
        naming,
        state_dict,
        torch_dt_mat=torch.float16,
        torch_dt_norm=torch.float32,
        ndev=1,
        transpose_weight=True,
    ):
        """Pack ``state_dict`` into the layout passed to the native library.

        Args:
            meta: Model meta-data; ``nlayer``, ``nh``, ``nkvh``, ``dh``, ``d``,
                ``di``, ``scale_input``, ``scale_output``, ``scale_o``,
                ``scale_down`` and ``torch_dtype_logits`` are read from it.
            naming: Object mapping logical weights to ``state_dict`` keys
                (see ``LlamaWeightsNaming``).
            state_dict: Mapping from tensor name to torch tensor.
            torch_dt_mat: Torch dtype for the projection matrices.
            torch_dt_norm: Torch dtype for the normalization weights.
            ndev: Number of devices the weights are sliced for.
            transpose_weight: Selects between the two packed layouts of the
                linear weights; stored as ``transpose_linear_weights``.

        Raises:
            ValueError: If the head counts or ``di`` are not divisible as the
                slicing requires, or if a dtype is not float16, float32 or
                bfloat16.
        """
        nlayer = meta.nlayer
        nh = meta.nh
        nkvh = meta.nkvh
        dh = meta.dh
        d = meta.d
        di = meta.di
        scale_input = meta.scale_input
        scale_output = meta.scale_output
        scale_o = meta.scale_o
        scale_down = meta.scale_down

        # Explicit checks rather than ``assert``: they must still run under
        # ``python -O``, otherwise the slicing below would silently misalign.
        if nh % nkvh != 0:
            raise ValueError("nh must be a multiple of nkvh")
        if nh % ndev != 0:
            raise ValueError("nh must be divisible by ndev")
        if nkvh % ndev != 0:
            raise ValueError("nkvh must be divisible by ndev")
        if di % ndev != 0:
            raise ValueError("di must be divisible by ndev")

        torch_dt_logits = meta.torch_dtype_logits
        infini_dtypes = {
            torch.float16: DataType.INFINI_DTYPE_F16,
            torch.float32: DataType.INFINI_DTYPE_F32,
            torch.bfloat16: DataType.INFINI_DTYPE_BF16,
        }
        if torch_dt_mat not in infini_dtypes:
            raise ValueError("Unsupported proj weight data type")
        self.dt_mat = infini_dtypes[torch_dt_mat]
        if torch_dt_norm not in infini_dtypes:
            raise ValueError("Unsupported norm weight data type")
        self.dt_norm = infini_dtypes[torch_dt_norm]

        # When only one of the two embedding tables is present in the
        # checkpoint, it is used for both the input and the output side.
        input_embd_naming = (
            naming.input_embd()
            if naming.input_embd() in state_dict
            else naming.output_embd()
        )
        output_embd_naming = (
            naming.output_embd()
            if naming.output_embd() in state_dict
            else naming.input_embd()
        )
        self.transpose_linear_weights = 1 if transpose_weight else 0
        self.nlayer = nlayer

        self.input_embd_tensor = (
            state_dict[input_embd_naming].to(torch_dt_logits) * scale_input
        )
        self.input_embd = self.input_embd_tensor.data_ptr()

        self.output_norm_tensor = (
            state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output
        )
        self.output_norm = self.output_norm_tensor.data_ptr()

        self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat)
        if not transpose_weight:
            self.output_embd_tensor = self.output_embd_tensor.transpose(
                0, 1
            ).contiguous()
        self.output_embd = self.output_embd_tensor.data_ptr()

        # Per-layer pre-attention norm weights.
        self.attn_norm_tensors = [
            state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.attn_norm_ptrs = [
            self.attn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs)

        def qkv_slices(_i):
            # Q and K: view each head's dh rows as two halves of dh // 2 rows
            # and swap those two axes, so row j of the first half ends up next
            # to row j of the second half. V is only reshaped to the same
            # 4-D form so that it can be concatenated with Q and K.
            _Q = (
                state_dict[naming.attn_q(_i)]
                .reshape([nh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _K = (
                state_dict[naming.attn_k(_i)]
                .reshape([nkvh, 2, dh // 2, d])
                .transpose(1, 2)
            )
            _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            # Each device gets its own heads, laid out as [Q | K | V].
            for _idev in range(ndev):
                _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :])
                _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :])
                _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :])
            return _result

        self.qkv_tensor = [
            torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        if not transpose_weight:
            # Per device, turn the [rows, d] slab into a contiguous [d, rows] one.
            for i in range(nlayer):
                self.qkv_tensor[i] = (
                    self.qkv_tensor[i]
                    .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d)
                    .transpose(1, 2)
                    .contiguous()
                )
        self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs)

        def qkv_b_slices(_i):
            # Same per-head rearrangement and per-device grouping as
            # qkv_slices, applied to the bias vectors.
            _QB = (
                state_dict[naming.attn_q_b(_i)]
                .reshape([nh, 2, dh // 2])
                .transpose(1, 2)
            )
            _KB = (
                state_dict[naming.attn_k_b(_i)]
                .reshape([nkvh, 2, dh // 2])
                .transpose(1, 2)
            )
            _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2])
            _result = []
            _nh = nh // ndev
            _nkvh = nkvh // ndev
            for _idev in range(ndev):
                _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten())
                _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
                _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten())
            return _result

        # QKV biases are optional; layer 0's query bias decides for all layers.
        if naming.attn_q_b(0) in state_dict:
            self.qkv_b_tensors = [
                torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer)
            ]
            self.qkv_b_tensor_ptrs = [
                self.qkv_b_tensors[i].data_ptr() for i in range(nlayer)
            ]
            self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs)
        else:
            self.attn_qkv_b = None

        # Attention output projection, multiplied by scale_o.
        self.attn_o_tensor = [
            (
                state_dict[naming.attn_o(i)]
                .to(torch_dt_mat)
                .reshape([d, ndev, nh // ndev * dh])
                .transpose(0, 1)
                .contiguous()
                if transpose_weight
                else state_dict[naming.attn_o(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
            * scale_o
            for i in range(nlayer)
        ]
        self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)]
        self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs)

        # Per-layer post-attention norm weights.
        self.ffn_norm_tensors = [
            state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer)
        ]
        self.ffn_norm_ptrs = [
            self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer)
        ]
        self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs)

        def gate_up_slices(_i):
            # Per device: its rows of the gate matrix followed by the same
            # rows of the up matrix.
            _result = []
            _di = di // ndev
            for _idev in range(ndev):
                _start = _idev * _di
                _end = (_idev + 1) * _di
                _result.append(state_dict[naming.gate(_i)][_start:_end, :])
                _result.append(state_dict[naming.up(_i)][_start:_end, :])
            return _result

        self.gate_up_tensors = [
            torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer)
        ]
        if not transpose_weight:
            for i in range(nlayer):
                self.gate_up_tensors[i] = (
                    self.gate_up_tensors[i]
                    .reshape(ndev, 2 * di // ndev, d)
                    .transpose(1, 2)
                    .contiguous()
                )
        self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)]
        self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs)

        # MLP down projection, multiplied by scale_down.
        self.ffn_down_tensor = [
            (
                state_dict[naming.down(i)]
                .to(torch_dt_mat)
                .reshape([d, ndev, di // ndev])
                .transpose(0, 1)
                .contiguous()
                if transpose_weight
                else state_dict[naming.down(i)]
                .transpose(0, 1)
                .to(torch_dt_mat)
                .contiguous()
            )
            * scale_down
            for i in range(nlayer)
        ]
        self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)]
        self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs)


class JiugeBatchedTask:
    """Inputs of one batched inference call, converted to ctypes arrays."""

    def __init__(self, tasks: List[InferTask]):
        """Collect the per-request fields of ``tasks`` and flatten their tokens.

        From each task this reads ``tokens``, ``pos``, ``temperature``,
        ``topk``, ``topp`` and ``kvcache().data()``.
        """
        self.tasks = tasks
        self.nreq = len(tasks)

        # Precompute fields
        token_lists = [t.tokens for t in tasks]
        self.req_lens_list = [len(toks) for toks in token_lists]
        self.req_pos_list = [t.pos for t in tasks]
        self.kv_cache_ptrs = [t.kvcache().data() for t in tasks]
        self.temperaturas_list = [t.temperature for t in tasks]
        self.topks_list = [t.topk for t in tasks]
        self.topps_list = [t.topp for t in tasks]

        # Flatten token lists: all requests' tokens back to back, in task order
        flat_tokens = [tok for toks in token_lists for tok in toks]
        self.ntok = len(flat_tokens)

        # Convert to ctypes arrays in one pass
        self.tokens = (c_uint * self.ntok)(*flat_tokens)
        self.req_lens = (c_uint * self.nreq)(*self.req_lens_list)
        self.req_pos = (c_uint * self.nreq)(*self.req_pos_list)
        self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs)
        self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list)
        self.topks = (c_uint * self.nreq)(*self.topks_list)
        self.topps = (c_float * self.nreq)(*self.topps_list)

    def input_args(self):
        """Return the arrays and counts in the order they are passed to ``infer_batch``."""
        return (
            self.tokens,
            self.ntok,
            self.req_lens,
            self.nreq,
            self.req_pos,
            self.kv_caches,
            self.temperaturas,
            self.topks,
            self.topps,
        )


class JiugeForCauslLM:
    """A causal language model backed by a native Jiuge model instance."""

    def __init__(
        self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None
    ):
        """Load the checkpoint in ``model_dir_path`` and create the native model.

        Args:
            model_dir_path: Directory holding ``config.json``, the tokenizer
                files and the weights.
            device: ``DeviceType`` on which the model is created.
            ndev: Number of devices; device ids ``0 .. ndev - 1`` are used.
            max_tokens: Context length override passed to the meta-data.

        Raises:
            ValueError: If ``model_type`` in the config is not supported or the
                checkpoint does not use the Llama weight naming.
        """
        print("Loading model weights to host...")
        load_start_time = time.time()

        with open(os.path.join(model_dir_path, "config.json"), "r") as f:
            config = json.load(f)
        self.config = config
        eos_token_id = self.config["eos_token_id"]
        # The config holds either one id or a list of ids; always keep a list.
        self.eos_token_id = (
            [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
        )
        transpose_weight = (
            device != DeviceType.DEVICE_TYPE_ASCEND
        )  # y = xW is faster than y=xW^T on Ascend

        model_type = config["model_type"]
        if model_type == "llama":
            model = (
                transformers.LlamaForCausalLM.from_pretrained(model_dir_path)
                .cpu()
                .half()
            )
            self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
            self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
            self.weights = JiugeWeightsImpl(
                self.meta,
                LlamaWeightsNaming(),
                model.state_dict(),
                ndev=ndev,
                transpose_weight=transpose_weight,
            )
        elif model_type in ("fm9g", "fm9g7b"):
            self._init_from_state_dict(
                model_dir_path,
                config,
                self._load_checkpoint(model_dir_path),
                ndev,
                max_tokens,
                transpose_weight,
                trust_remote_code=True,
            )
        elif model_type == "qwen2":
            self._init_from_state_dict(
                model_dir_path,
                config,
                self._load_all_safetensors_from_dir(model_dir_path),
                ndev,
                max_tokens,
                transpose_weight,
            )
        else:
            raise ValueError("Unsupported model architecture")

        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")

        print(f"Creating model on {ndev} devices...")
        load_start_time = time.time()
        dev_ids = (c_int * ndev)(*range(ndev))
        self.model_instance = create_jiuge_model(
            byref(self.meta),
            byref(self.weights),
            device,
            ndev,
            dev_ids,
        )
        load_end_time = time.time()
        print(f"Time used: {load_end_time - load_start_time:.3f}s")

    @staticmethod
    def _load_all_safetensors_from_dir(dir_path_: str):
        """Read every tensor of every ``*.safetensors`` file in a directory."""
        tensors_ = {}
        for file in sorted(Path(dir_path_).glob("*.safetensors")):
            # Context manager so each file handle is released once it is read.
            with safetensors.safe_open(file, "pt") as data_:
                for name_ in data_.keys():
                    tensors_[name_] = data_.get_tensor(name_)
        return tensors_

    @staticmethod
    def _load_checkpoint(model_dir_path):
        """Load safetensors shards if the directory has any, else ``pytorch_model.bin``."""
        if any(
            file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
        ):
            return JiugeForCauslLM._load_all_safetensors_from_dir(model_dir_path)
        return torch.load(
            os.path.join(model_dir_path, "pytorch_model.bin"),
            weights_only=True,
            map_location="cpu",
        )

    def _init_from_state_dict(
        self,
        model_dir_path,
        config,
        state_dict,
        ndev,
        max_tokens,
        transpose_weight,
        **tokenizer_kwargs,
    ):
        """Set ``meta``, ``weights`` and ``tokenizer`` from a raw state dict.

        Raises:
            ValueError: If ``state_dict`` does not use the Llama weight naming.
        """
        if not LlamaWeightsNaming.match(state_dict):
            raise ValueError("Unsupported weight naming")
        self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens)
        self.weights = JiugeWeightsImpl(
            self.meta,
            LlamaWeightsNaming(),
            state_dict,
            ndev=ndev,
            transpose_weight=transpose_weight,
        )
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_dir_path, **tokenizer_kwargs
        )

    def max_context_len(self):
        """Return the context length stored in the meta-data (``dctx``)."""
        return self.meta.dctx

    def create_kv_cache(self):
        """Create a KV cache for this model instance."""
        return create_kv_cache(self.model_instance)

    def drop_kv_cache(self, kv_cache):
        """Release a KV cache created for this model instance."""
        drop_kv_cache(self.model_instance, kv_cache)

    def batch_infer_one_round(self, tasks: List[InferTask]):
        """Run one inference step for ``tasks``; return one output token per task."""
        output = (c_uint * len(tasks))()
        batch_inputs = JiugeBatchedTask(tasks)
        infer_batch(
            self.model_instance,
            *(batch_inputs.input_args()),
            output,
        )
        return list(output)

    def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0):
        """Generate a reply to ``input_content``, echoing text to stdout.

        Args:
            input_content: User message; it is wrapped with the tokenizer's
                chat template before encoding.
            max_steps: Maximum number of inference steps.
            topp_: Top-p value passed to the inference task.
            topk_: Top-k value passed to the inference task.
            temperature_: Temperature passed to the inference task.

        Returns:
            ``(output_content, avg_time)``: the generated text and the mean
            time per step in milliseconds. The first step is excluded from the
            mean; ``avg_time`` is ``0.0`` when no later step was run.
        """
        input_content = self.tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": input_content}],
            add_generation_prompt=True,
            tokenize=False,
        )
        print(input_content, end="", flush=True)
        tokens = self.tokenizer.encode(input_content)
        infer_task = InferTask(
            0,
            tokens,
            self.max_context_len(),
            temperature_,
            topk_,
            topp_,
            self.eos_token_id,
        )
        infer_task.bind_kvcache(KVCache(self))

        timed_steps = 0
        total_time = 0
        output_content = ""

        try:
            for step_i in range(max_steps):
                start_time = time.time()
                output_tokens = self.batch_infer_one_round([infer_task])
                end_time = time.time()
                # Step 0 processes the whole prompt, so it is left out of the
                # per-step average. Count each timed step as it is measured so
                # the divisor always matches the number of summed samples.
                if step_i > 0:
                    total_time += end_time - start_time
                    timed_steps += 1

                output_str = (
                    self.tokenizer._tokenizer.id_to_token(output_tokens[0])
                    .replace("▁", " ")
                    .replace("<0x0A>", "\n")
                )
                output_content += output_str
                print(output_str, end="", flush=True)
                if output_tokens[0] in self.eos_token_id:
                    break
                infer_task.next(output_tokens[0])
        finally:
            # Release the KV cache even if inference raised.
            infer_task._kv_cache.drop(self)

        print("\n")
        avg_time = total_time * 1000 / timed_steps if timed_steps else 0.0
        print(f"Time per step: {avg_time:.3f}ms")
        return output_content, avg_time

    def destroy_model_instance(self):
        """Destroy the native model instance."""
        destroy_jiuge_model(self.model_instance)
        print("Model destroyed")


def test():
    """Command-line entry point: load a model and generate one reply.

    Expects ``sys.argv`` to be ``[prog, device_flag, model_dir, [n_device]]``.
    Prints the usage line and exits with status 1 when arguments are missing
    or the device flag is unknown.
    """
    usage = "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] <path/to/model_dir> [n_device]"
    # Flag -> name of the DeviceType member, resolved only for the chosen flag.
    device_flags = {
        "--cpu": "DEVICE_TYPE_CPU",
        "--nvidia": "DEVICE_TYPE_NVIDIA",
        "--cambricon": "DEVICE_TYPE_CAMBRICON",
        "--ascend": "DEVICE_TYPE_ASCEND",
        "--metax": "DEVICE_TYPE_METAX",
        "--moore": "DEVICE_TYPE_MOORE",
    }
    if len(sys.argv) < 3 or sys.argv[1] not in device_flags:
        print(usage)
        sys.exit(1)

    model_path = sys.argv[2]
    device_type = getattr(DeviceType, device_flags[sys.argv[1]])
    ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1
    model = JiugeForCauslLM(model_path, device_type, ndev)
    try:
        model.generate("山东最高的山是?", 500)
    finally:
        # Destroy the native model even when generation fails.
        model.destroy_model_instance()


# Run the command-line demo only when this file is executed as a script.
if __name__ == "__main__":
    test()