"""Run text generation with an InfiniLM Llama model.

Usage:
    python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] \
        --model_path=<path/to/model_dir>
"""

import argparse
import os
import sys
import time

# Make the in-repo ``python`` directory importable. This has to run *before*
# the infinicore/infinilm imports below; inserting it afterwards (as the
# original ordering did) has no effect on modules that are already imported.
sys.path.insert(
    0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "../python")
)

import infinicore  # noqa: E402
import infinilm  # noqa: E402
from infinilm.modeling_utils import get_model_state_dict  # noqa: E402
from tokenizers import decoders as _dec  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402


def _positive_int(text):
    """argparse ``type`` callback: parse *text* as an int and reject values below 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {text!r}"
        )
    return value


def get_args():
    """Parse the command-line options of this example from ``sys.argv``.

    Device flags (``--cpu``, ``--nvidia``, ``--metax``, ``--moore``,
    ``--iluvatar``) are independent booleans; choosing among them is left to
    the caller.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description="run Llama args")

    # Device selection flags.
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run cpu test",
    )
    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run nvidia test",
    )
    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run metax test",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run moore test",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run iluvatar test",
    )

    # Model and generation options.
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="model_path",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="max_new_tokens",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="python",
        help="python or cpp model",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        # Restrict to the dtypes the entry point can map to infinicore, so a
        # typo is reported by argparse instead of after startup.
        choices=["float32", "float16", "bfloat16"],
        help="float32, float16, bfloat16",
    )
    parser.add_argument(
        "--batch_size",
        # A batch of 0 would produce an empty prompt list downstream.
        type=_positive_int,
        default=1,
        help="number of prompts in a batch",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How are you",
        help="input prompt",
    )

    return parser.parse_args()


def _patch_llama_decoder(tokenizer):
    """Swap out a fast-tokenizer decoder that contains a ``Strip`` step.

    Best-effort heuristic: when the backend tokenizer's normalizer ``repr``
    mentions ``Prepend`` and its decoder ``repr`` mentions ``Strip``, the
    decoder is replaced by ``Replace("▁", " ") -> ByteFallback -> Fuse``,
    i.e. the same pipeline without the ``Strip`` step. Otherwise the
    tokenizer is left untouched.
    """
    backend_tok = getattr(tokenizer, "backend_tokenizer", None)
    target = getattr(backend_tok, "_tokenizer", backend_tok)
    normalizer = getattr(target, "normalizer", None)
    decoder = getattr(target, "decoder", None)
    norm_repr = repr(normalizer)[:800] if normalizer is not None else ""
    dec_repr = repr(decoder)[:800] if decoder is not None else ""
    if "Prepend" in norm_repr and "Strip" in dec_repr:
        target.decoder = _dec.Sequence(
            [
                _dec.Replace("▁", " "),
                _dec.ByteFallback(),
                _dec.Fuse(),
            ]
        )


def test(
    prompts: str | list[str],
    model_path: str,
    max_new_tokens: int = 100,
    infini_dtype=infinicore.bfloat16,
    infini_device=infinicore.device("cpu", 0),
    backend: str = "python",
):
    """Load a Llama model from *model_path* and run ``model.generate`` on *prompts*.

    Args:
        prompts: A single prompt or a list of prompts. Each one is wrapped as a
            single-turn user message and rendered with the tokenizer's chat
            template (``add_generation_prompt=True``).
        model_path: Model directory; ``~`` is expanded. Used for the model
            config, the weights and the tokenizer.
        max_new_tokens: Forwarded to ``model.generate``.
        infini_dtype: infinicore dtype used for the model and its weights.
        infini_device: infinicore device used for the model and its weights.
        backend: Forwarded to ``AutoLlamaModel.from_pretrained``.

    Prints the first rendered prompt, a start banner and the total generation
    time in milliseconds. The return value of ``model.generate`` is discarded.

    Raises:
        ValueError: if *prompts* is empty, or the loaded model type is not
            ``"llama"``.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    # Fail before the (expensive) model load instead of with an IndexError later.
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")

    model_path = os.path.expanduser(model_path)

    # ---------------------------------------------------------------------------- #
    #                        Create the model
    # ---------------------------------------------------------------------------- #
    model = infinilm.AutoLlamaModel.from_pretrained(
        model_path,
        device=infini_device,
        dtype=infini_dtype,
        backend=backend,
    )

    # ---------------------------------------------------------------------------- #
    #                        Load the weights
    # ---------------------------------------------------------------------------- #
    state_dict = get_model_state_dict(
        model_path,
        device=infini_device,
        dtype=infini_dtype,
    )
    model.load_state_dict(state_dict, strict=True)

    # ---------------------------------------------------------------------------- #
    #                        Create the tokenizer
    # ---------------------------------------------------------------------------- #
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    model_type = model.config.model_type
    if model_type != "llama":
        raise ValueError(f"Unsupported model type: {model_type}")
    _patch_llama_decoder(tokenizer)

    # ---------------------------------------------------------------------------- #
    #                        Encode the prompts
    # ---------------------------------------------------------------------------- #
    input_contents = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for prompt in prompts
    ]
    print(input_contents[0], end="", flush=True)
    # Calling the tokenizer on a list batch-encodes it (same as the legacy
    # batch_encode_plus): one list of token ids per prompt.
    # NOTE(review): rows are not padded; infinicore.from_list presumably needs
    # equal-length rows (true when all prompts are identical) — confirm.
    input_ids_list = tokenizer(input_contents)["input_ids"]

    # ---------------------------------------------------------------------------- #
    #                        Autoregressive generation
    # ---------------------------------------------------------------------------- #
    input_ids_infini = infinicore.from_list(input_ids_list)

    # perf_counter is monotonic and high-resolution, suited to timing an interval.
    t1 = time.perf_counter()
    print("=================== start generate ====================")
    model.generate(
        input_ids_infini,
        max_new_tokens=max_new_tokens,
        tokenizer=tokenizer,
    )
    t2 = time.perf_counter()

    print(
        f"total_time: {round((t2 - t1) * 1000, 2)} ms",
    )


if __name__ == "__main__":
    args = get_args()
    print(args)

    # Parse command line arguments
180
    device_str = "cpu"
181
    if args.cpu:
182
        device_str = "cpu"
183
    elif args.nvidia:
184
        device_str = "cuda"
185
    elif args.metax:
186
        device_str = "cuda"
187
    elif args.moore:
188
        device_str = "musa"
189
    elif args.iluvatar:
190
        device_str = "cuda"
191
192
    else:
        print(
193
194
            "Usage:  python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>\n"
            "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
195
196
        )
        sys.exit(1)
PanZezhong's avatar
PanZezhong committed
197
    prompts = [args.prompt for _ in range(args.batch_size)]
198
199
200

    model_path = args.model_path
    max_new_tokens = args.max_new_tokens
201
202
203
    backend = args.backend

    infini_device = infinicore.device(device_str, 0)
204
205
206
207
208
209
210
211
    if args.dtype == "float32":
        infini_dtype = infinicore.float32
    elif args.dtype == "bfloat16":
        infini_dtype = infinicore.bfloat16
    elif args.dtype == "float16":
        infini_dtype = infinicore.float16
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")
212

213
    test(
PanZezhong's avatar
PanZezhong committed
214
        prompts,
215
216
217
218
219
220
        model_path,
        max_new_tokens,
        infini_device=infini_device,
        infini_dtype=infini_dtype,
        backend=backend,
    )