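"""Example: run a Llama chat model with the InfiniLM / InfiniCore stack.

The script builds the model, loads its weights, patches the tokenizer
decoder so token-by-token decoding keeps spacing intact, and times
autoregressive generation.
"""
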
import argparse
import os
import sys
import time

# Prefer the in-repo ``python`` package over any installed copy; the path
# must be added before the infinicore / infinilm imports below.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

import infinicore
import infinilm
from infinilm.modeling_utils import get_model_state_dict
from tokenizers import decoders as _dec
from transformers import AutoTokenizer


def get_args():
    parser = argparse.ArgumentParser(description="Run a Llama model example")

    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run on the CPU",
    )
    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run on an NVIDIA GPU",
    )
    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run on a MetaX GPU",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run on a Moore Threads GPU",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run on an Iluvatar GPU",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the model directory",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="Maximum number of new tokens to generate",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="python",
        help="Model backend: 'python' or 'cpp'",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="Data type: float32, float16, or bfloat16",
    )
    return parser.parse_args()


def test(
    prompt,
    model_path,
    max_new_tokens=100,
    infini_dtype=infinicore.bfloat16,
    infini_device=infinicore.device("cpu", 0),
    backend="python",
):
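    """Load a Llama model from ``model_path`` and time text generation.

    ``infini_device`` and ``infini_dtype`` are infinicore device/dtype
    objects; ``backend`` selects the "python" or "cpp" model implementation
    (see ``--backend``).
    """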
    # ---------------------------------------------------------------------------- #
    #                        Create the model
    # ---------------------------------------------------------------------------- #
    model = infinilm.AutoLlamaModel.from_pretrained(
        model_path, device=infini_device, dtype=infini_dtype, backend=backend
    )

    # ---------------------------------------------------------------------------- #
    #                        Load the weights
    # ---------------------------------------------------------------------------- #
    model_param_infini = get_model_state_dict(
        model_path,
        device=infini_device,
        dtype=infini_dtype,
    )

    model.load_state_dict(model_param_infini)

    config = model.config

    # ---------------------------------------------------------------------------- #
    #                        Create the tokenizer
    # ---------------------------------------------------------------------------- #
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    if "llama" == config.model_type:
        backend = getattr(tokenizer, "backend_tokenizer", None)
        target = getattr(backend, "_tokenizer", backend)
        norm = getattr(target, "normalizer", None)
        dec = getattr(target, "decoder", None)
        sn = repr(norm)[:800] if norm is not None else ""
        sd = repr(dec)[:800] if dec is not None else ""
        has_prepend = "Prepend" in sn
        has_strip = "Strip" in sd
        if has_prepend and has_strip:
            target.decoder = _dec.Sequence(
                [
                    _dec.Replace("▁", " "),
                    _dec.ByteFallback(),
                    _dec.Fuse(),
                ]
            )
119
120
    else:
        raise ValueError(f"Unsupported model type: {config.model_type}")
121
122
123
124

    # ---------------------------------------------------------------------------- #
    #                        Encode the prompt
    # ---------------------------------------------------------------------------- #
    # prompt = "What is the highest mountain in Shandong?"
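    # Wrap the raw prompt in the model's chat template; add_generation_prompt=True
    # appends the assistant-turn prefix so the model continues as the assistant.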
    input_content = tokenizer.apply_chat_template(
        conversation=[{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=False,
    )
    print(input_content, end="", flush=True)
    input_ids = tokenizer.encode(input_content)

    # ---------------------------------------------------------------------------- #
    #                        Autoregressive generation
    # ---------------------------------------------------------------------------- #
    input_ids_list = [input_ids]  # List: [[1, 1128, 526, 366, 29892]]
    input_ids_infini = infinicore.from_list(input_ids_list)

    t1 = time.time()
    print("=================== start generate ====================")
    model.generate(
        input_ids_infini,
        max_new_tokens=max_new_tokens,
        device=infini_device,
        tokenizer=tokenizer,
        config=config,
    )
    t2 = time.time()
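    # The reported time covers the whole generate() call (prefill + decode);
    # dividing by the number of generated tokens gives a rough per-token cost.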

    print(f"total_time: {round((t2 - t1) * 1000, 2)} ms")


if __name__ == "__main__":
    args = get_args()
    print(args)

    # Map the selected device flag to an infinicore device string (MetaX and
    # Iluvatar GPUs are addressed via the CUDA-compatible device string).
    if args.cpu:
        device_str = "cpu"
    elif args.nvidia:
        device_str = "cuda"
    elif args.metax:
        device_str = "cuda"
    elif args.moore:
        device_str = "musa"
    elif args.iluvatar:
        device_str = "cuda"
    else:
        print(
            "Usage: python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>\n"
            "e.g.,  python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
        )
        sys.exit(1)
    prompt = "How are you"

    model_path = args.model_path
    max_new_tokens = args.max_new_tokens
    backend = args.backend

    infini_device = infinicore.device(device_str, 0)
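    # Device index 0 is hard-coded; multi-device selection is not exposed here.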
    if args.dtype == "float32":
        infini_dtype = infinicore.float32
    elif args.dtype == "bfloat16":
        infini_dtype = infinicore.bfloat16
    elif args.dtype == "float16":
        infini_dtype = infinicore.float16
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    test(
        prompt,
        model_path,
        max_new_tokens,
        infini_device=infini_device,
        infini_dtype=infini_dtype,
        backend=backend,
    )