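"""Jiuge inference example.

Usage:
    python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore |
        --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>

Example:
    python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0
"""
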
import infinicore
import transformers

from transformers import AutoTokenizer
from tokenizers import decoders as _dec
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine

import argparse
import sys
import time
import os
import numpy as np
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from packaging import version

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

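# Default number of tokens per paged-KV block. test() reads this module-level
# value; __main__ rebinds it from --paged_kv_block_size before calling test().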
_PAGED_KV_BLOCK_SIZE = 256


def get_args():
    parser = argparse.ArgumentParser(description="Run Llama inference example")

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run cpu test",
    )
    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run nvidia test",
    )
    parser.add_argument(
        "--qy",
        action="store_true",
        help="Run qy test",
    )
    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run metax test",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run moore test",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run iluvatar test",
    )
    parser.add_argument(
        "--cambricon",
        action="store_true",
        help="Run cambricon test",
    )
    parser.add_argument(
        "--ali",
        action="store_true",
        help="Run alippu test",
    )
    parser.add_argument(
        "--hygon",
        action="store_true",
        help="Run hygon test",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="model_path",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="max_new_tokens",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="cpp",
        help="model backend (only 'cpp' is currently supported)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="number of prompts in a batch",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How are you",
        help="input prompt",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="number of ranks for tensor parallelism",
    )
    parser.add_argument(
        "--enable-paged-attn",
        action="store_true",
        help="use paged cache",
    )

    parser.add_argument(
        "--paged_kv_block_size",
        type=int,
        default=256,
        help="num tokens each kv block can hold",
    )

    parser.add_argument(
        "--enable-graph",
        action="store_true",
        help="enable graph compiling",
    )

    parser.add_argument(
        "--top-k",
        type=int,
        default=1,
        help="top k sampling",
    )

    parser.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="top p sampling",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="sampling temperature",
    )

    parser.add_argument(
        "--attn",
        type=str,
        default="default",
        choices=["default", "flash-attn"],
        help="attention backend to use: 'default' or 'flash-attn'",
    )

    return parser.parse_args()


def test(
    prompts: str | list[str],
    model_path,
    max_new_tokens=100,
    infini_device=infinicore.device("cpu", 0),
    tp=1,
    enable_paged_attn=False,
    enable_graph=False,
    top_k=1,
    top_p=1.0,
    temperature=1.0,
    attn_backend="default",
):
    model_path = os.path.expanduser(model_path)
    # ---------------------------------------------------------------------------- #
    #                        Create Model
    # ---------------------------------------------------------------------------- #
    model = InferEngine(
        model_path,
        device=infini_device,
        distributed_config=DistConfig(tp),
        enable_graph_compiling=enable_graph,
        attention_backend=attn_backend,
    )
    # ---------------------------------------------------------------------------- #
    #                        Load Weights
    # ---------------------------------------------------------------------------- #
    load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)

    # ---------------------------------------------------------------------------- #
    #                        Create Tokenizer
    # ---------------------------------------------------------------------------- #
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
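    # Workaround for SentencePiece-style Llama tokenizers: when the normalizer
    # prepends "▁" and the decoder chain ends in a Strip step, decoding token IDs
    # one at a time eats the leading space of each piece. Replacing the decoder
    # with just Replace/ByteFallback/Fuse (no Strip) preserves spacing during
    # incremental decoding; the repr() probes below detect that configuration.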
    if "llama" == model.config.model_type:
        backend = getattr(tokenizer, "backend_tokenizer", None)
        target = getattr(backend, "_tokenizer", backend)
        norm = getattr(target, "normalizer", None)
        dec = getattr(target, "decoder", None)
        sn = repr(norm)[:800] if norm is not None else ""
        sd = repr(dec)[:800] if dec is not None else ""
        has_prepend = "Prepend" in sn
        has_strip = "Strip" in sd
        if has_prepend and has_strip:
            target.decoder = _dec.Sequence(
                [
                    _dec.Replace("▁", " "),
                    _dec.ByteFallback(),
                    _dec.Fuse(),
                ]
            )

    # ---------------------------------------------------------------------------- #
    #                        Tokenize
    # ---------------------------------------------------------------------------- #
    # prompt = "山东最高的山是?"
    if isinstance(prompts, str):
        prompts = [prompts]
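    # apply_chat_template only renders the chat-formatted prompt string here
    # (tokenize=False); the token IDs themselves are produced further below.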
    input_contents = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for prompt in prompts
    ]

    # input_ids_list = tokenizer.batch_encode_plus(input_contents)[
    #     "input_ids"
    # ]  # List: [[1, 1128, 526, 366, 29892]]
    if version.parse(transformers.__version__) < version.parse("5.0.0"):
        # Ideally this would be solved by upgrading transformers, but upgrading causes
        # a version mismatch between transformers and the MLU PyTorch build on machines
        # with Phytium CPUs, so this version branch is used temporarily.
        input_ids_list = [
            tokenizer.encode_plus(
                text, truncation=True, max_length=2048, add_special_tokens=True
            )["input_ids"]
            for text in input_contents
        ]
    else:
        input_ids_list = [
            tokenizer._encode_plus(
                text, truncation=True, max_length=2048, add_special_tokens=True
            )["input_ids"]
            for text in input_contents
        ]

    # ---------------------------------------------------------------------------- #
    #                       Create KVCache
    # ---------------------------------------------------------------------------- #
    if enable_paged_attn:
        batch_size = len(prompts)  # prompts was normalized to a list above
        max_total_tokens = max_new_tokens + len(input_ids_list[0])
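        # Reserve ceil(max_total_tokens / block_size) KV blocks per sequence.
        # Sizing from input_ids_list[0] assumes every prompt in the batch has the
        # same length, which holds here because the batch repeats a single prompt.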
        cache_config = PagedKVCacheConfig(
            num_blocks=(
                (max_total_tokens + (_PAGED_KV_BLOCK_SIZE - 1)) // _PAGED_KV_BLOCK_SIZE
            )
            * batch_size,
            block_size=_PAGED_KV_BLOCK_SIZE,
        )
    else:
        batch_size = len(prompts)
        initial_capacity = max_new_tokens + len(input_ids_list[0])
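        # The static cache pre-allocates one fixed-size region per sequence,
        # large enough for the prompt plus all generated tokens (again assuming
        # equal-length prompts).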
        cache_config = StaticKVCacheConfig(
            max_batch_size=batch_size, max_cache_len=initial_capacity
        )

    model.reset_cache(cache_config)

    # ---------------------------------------------------------------------------- #
    #                        Generate
    # ---------------------------------------------------------------------------- #
    print(input_contents[0], end="", flush=True)
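    # Pack the nested token-ID lists into an infinicore tensor for the engine.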
    input_ids_infini = infinicore.from_list(input_ids_list)

    t1 = time.time()
    print("=================== start generate ====================")
    output_ids = model.generate(
        input_ids_infini,
        GenerationConfig(
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        ),
        _measure_and_log_time=True,
    )
    t2 = time.time()

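    # Each entry in output_ids holds one decode step; indexing [0] keeps only the
    # first sequence, so just batch row 0 is decoded and printed.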
    numpy_output_ids = np.array([output_id.to_numpy()[0] for output_id in output_ids])
    print(tokenizer.decode(numpy_output_ids, skip_special_tokens=True))

    print(f"total_time: {round((t2 - t1) * 1000, 2)} ms")


if __name__ == "__main__":
    args = get_args()
    print(args)

    # Map the selected device flag to an infinicore device string
    device_str = "cpu"
    if args.cpu:
        device_str = "cpu"
    elif args.nvidia:
        device_str = "cuda"
    elif args.qy:
        device_str = "cuda"
    elif args.metax:
        device_str = "cuda"
    elif args.moore:
        device_str = "musa"
    elif args.iluvatar:
        device_str = "cuda"
    elif args.cambricon:
        device_str = "mlu"
    elif args.ali:
        device_str = "cuda"
    elif args.hygon:
        device_str = "cuda"
    else:
        print(
            "Usage: python examples/jiuge.py [--cpu | --nvidia | --qy | --metax | --moore | --iluvatar | --cambricon | --ali | --hygon] --model_path=<path/to/model_dir>\n"
            "e.g., python examples/jiuge.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
        )
        sys.exit(1)
    prompts = [args.prompt for _ in range(args.batch_size)]
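    # Rebind the module-level block-size default so test() sees the CLI value.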
    _PAGED_KV_BLOCK_SIZE = args.paged_kv_block_size

    model_path = args.model_path
    max_new_tokens = args.max_new_tokens
    backend = args.backend
    tp = args.tp
    enable_paged_attn = args.enable_paged_attn
    enable_graph = args.enable_graph
    if backend != "cpp":
        raise ValueError(f"Unsupported backend: {backend}.")

    infini_device = infinicore.device(device_str, 0)

    test(
        prompts,
        model_path,
        max_new_tokens,
        infini_device=infini_device,
        tp=tp,
        enable_paged_attn=enable_paged_attn,
        enable_graph=enable_graph,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        attn_backend=args.attn,
    )