import argparse
import os
import sys
import time

# Put the repository's ``python/`` directory (relative to this script) at the
# front of the import path. This has to happen BEFORE the project packages
# below are imported — previously it ran after ``import infinilm`` and so had
# no effect on it. Presumably ``../python`` holds the in-repo ``infinilm``
# package; confirm against the repository layout.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

import infinicore  # noqa: E402
import infinilm  # noqa: E402
from infinilm.modeling_utils import get_model_state_dict  # noqa: E402
from tokenizers import decoders as _dec  # noqa: E402
from transformers import AutoTokenizer  # noqa: E402


def _positive_int(text):
    """argparse ``type=`` hook: parse *text* as an int and reject values below 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def get_args(argv=None):
    """Build this example's command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is the previous no-argument behaviour.

    Returns:
        argparse.Namespace with boolean device flags ``cpu``, ``nvidia``,
        ``metax``, ``moore``, ``iluvatar`` and the options ``model_path``
        (str, required), ``max_new_tokens`` (int, default 100), ``backend``
        (str, default "python"), ``batch_size`` (int >= 1, default 1) and
        ``prompt`` (str, default "How are you").
    """
    parser = argparse.ArgumentParser(description="run Llama args")

    # One boolean flag per target device family; the caller decides which
    # infinicore device type each flag maps to.
    for device_name in ("cpu", "nvidia", "metax", "moore", "iluvatar"):
        parser.add_argument(
            f"--{device_name}",
            action="store_true",
            help=f"Run {device_name} test",
        )

    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="model_path",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="max_new_tokens",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="python",
        help="python or cpp model",
    )
    # Must be >= 1: the prompt list is built from this count and the first
    # prompt is echoed, so an empty batch would fail much later.
    parser.add_argument(
        "--batch_size",
        type=_positive_int,
        default=1,
        help="number of prompts in a batch",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How are you",
        help="input prompt",
    )

    return parser.parse_args(argv)


def _patch_llama_decoder(tokenizer):
    """Swap a ``Strip``-ing fast-tokenizer decoder for one without ``Strip``.

    The patch is applied only when the backend normalizer's ``repr`` mentions
    ``Prepend`` and the decoder's ``repr`` mentions ``Strip``. Otherwise the
    tokenizer is left untouched.
    """
    backend = getattr(tokenizer, "backend_tokenizer", None)
    # Some wrappers keep the underlying tokenizer one level down in ``_tokenizer``.
    target = getattr(backend, "_tokenizer", backend)
    normalizer = getattr(target, "normalizer", None)
    decoder = getattr(target, "decoder", None)
    has_prepend = normalizer is not None and "Prepend" in repr(normalizer)
    has_strip = decoder is not None and "Strip" in repr(decoder)
    if has_prepend and has_strip:
        # Presumably the stock Llama decode pipeline minus its Strip step, so
        # text decoded piece by piece during generation keeps each token's
        # leading space. TODO confirm against how model.generate decodes.
        target.decoder = _dec.Sequence(
            [
                _dec.Replace("▁", " "),
                _dec.ByteFallback(),
                _dec.Fuse(),
            ]
        )


def _encode_chat_prompts(tokenizer, prompts):
    """Wrap each prompt as one user chat turn, echo the first, return token-id lists."""
    rendered = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for prompt in prompts
    ]
    print(rendered[0], end="", flush=True)
    # NOTE(review): no padding is applied, so prompts that tokenize to
    # different lengths give a ragged list; confirm infinicore.from_list
    # accepts that before relying on mixed-length batches.
    # Calling the tokenizer on a list is the non-deprecated equivalent of
    # ``batch_encode_plus``, e.g. [[1, 1128, 526, 366, 29892]].
    return tokenizer(rendered)["input_ids"]


def test(
    prompts: str | list[str],
    model_path,
    max_new_tokens=100,
    infini_device=None,
):
    """Load a Llama checkpoint with InfiniLM and generate a reply for the prompts.

    Args:
        prompts: One prompt or a list of prompts. Each is rendered through the
            tokenizer's chat template as a single user turn.
        model_path: Checkpoint directory; ``~`` is expanded.
        max_new_tokens: Forwarded to ``model.generate``.
        infini_device: infinicore device to load onto. ``None`` means
            ``infinicore.device("cpu", 0)``. That is the previous default, but
            it is now created on call rather than at import time.

    Raises:
        ValueError: if ``prompts`` is empty, or the checkpoint's
            ``config.model_type`` is not ``"llama"``.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")
    if infini_device is None:
        infini_device = infinicore.device("cpu", 0)
    model_path = os.path.expanduser(model_path)

    # ---------------------------------------------------------------------------- #
    #                        Create the model
    # ---------------------------------------------------------------------------- #
    model = infinilm.AutoLlamaModel.from_pretrained(
        model_path,
        device=infini_device,
    )
    # Reject unsupported checkpoints before spending time loading weights.
    if model.config.model_type != "llama":
        raise ValueError(f"Unsupported model type: {model.config.model_type}")

    # ---------------------------------------------------------------------------- #
    #                        Load the weights
    # ---------------------------------------------------------------------------- #
    model_param_infini = get_model_state_dict(
        model_path,
        device=infini_device,
        dtype=model.config.dtype,
    )
    model.load_state_dict(model_param_infini, strict=True)

    # ---------------------------------------------------------------------------- #
    #                        Create the tokenizer
    # ---------------------------------------------------------------------------- #
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _patch_llama_decoder(tokenizer)

    # ---------------------------------------------------------------------------- #
    #                        Encode the prompts
    # ---------------------------------------------------------------------------- #
    input_ids_list = _encode_chat_prompts(tokenizer, prompts)

    # ---------------------------------------------------------------------------- #
    #                        Autoregressive generation
    # ---------------------------------------------------------------------------- #
    input_ids_infini = infinicore.from_list(input_ids_list)

    start = time.perf_counter()
    print("=================== start generate ====================")
    model.generate(
        input_ids_infini,
        max_new_tokens=max_new_tokens,
        tokenizer=tokenizer,
    )
    elapsed = time.perf_counter() - start

    print(
        f"total_time: {round(elapsed * 1000, 2)} ms",
    )


if __name__ == "__main__":
    args = get_args()
    print(args)

    # Map the device flags to an infinicore device type. The flags are checked
    # in this order and the first one that is set wins, as before.
    device_flags = (
        ("cpu", "cpu"),
        ("nvidia", "cuda"),
        ("metax", "cuda"),
        ("moore", "musa"),
        ("iluvatar", "cuda"),
    )
    device_str = next(
        (device for flag, device in device_flags if getattr(args, flag)),
        None,
    )
    if device_str is None:
        # No device flag given: show usage and exit with an error status.
        print(
            "Usage:  python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>\n"
            "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0",
            file=sys.stderr,
        )
        sys.exit(1)

    # Only the Python model implementation is wired up in this example.
    backend = args.backend
    if backend != "python":
        raise ValueError(f"Unsupported backend: {backend}.")

    # The same prompt is repeated batch_size times.
    prompts = [args.prompt for _ in range(args.batch_size)]

    test(
        prompts,
        args.model_path,
        args.max_new_tokens,
        infini_device=infinicore.device(device_str, 0),
    )