llama.py 6.61 KB
Newer Older
1
2
3
4
5
import argparse
import os
import sys
import time

# BUGFIX: the path insert must happen BEFORE the project imports below,
# otherwise the in-repo ``../python`` package directory is added too late
# for ``infinicore`` / ``infinilm`` to resolve from the source tree.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../python"))

import infinicore
from transformers import AutoTokenizer
from tokenizers import decoders as _dec

import infinilm
from infinilm.distributed import DistConfig
from infinilm.modeling_utils import get_model_state_dict


def get_args(argv=None):
    """Parse command-line options for the Llama example script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case ``sys.argv[1:]`` is used — so existing callers that
            invoke ``get_args()`` with no arguments behave exactly as before.
            (Generalized so the parser can also be driven programmatically.)

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="run Llama args")

    # Hardware selection flags — the caller treats them as mutually
    # exclusive and checks them in this order.
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Run cpu test",
    )
    parser.add_argument(
        "--nvidia",
        action="store_true",
        help="Run nvidia test",
    )
    parser.add_argument(
        "--metax",
        action="store_true",
        help="Run metax test",
    )
    parser.add_argument(
        "--moore",
        action="store_true",
        help="Run moore test",
    )
    parser.add_argument(
        "--iluvatar",
        action="store_true",
        help="Run iluvatar test",
    )

    # Model / generation options.
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="model_path",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=100,
        help="max_new_tokens",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="python",
        help="python or cpp model",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="float32, float16, bfloat16",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="number of prompts in a batch",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="How are you",
        help="input prompt",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=None,
        help="total rank for tensor parallel",
    )

    return parser.parse_args(argv)


89
def test(
    prompts: str | list[str],
    model_path,
    max_new_tokens=100,
    infini_dtype=infinicore.bfloat16,
    infini_device=infinicore.device("cpu", 0),
    backend="python",
    tp=None,
):
    """Load a Llama checkpoint, run batched chat generation, report timing.

    Args:
        prompts: A single prompt string or a list of prompts (one per batch
            entry). A bare string is wrapped into a one-element batch.
        model_path: Directory of the pretrained model; ``~`` is expanded.
        max_new_tokens: Maximum number of tokens to generate.
        infini_dtype: infinicore dtype for weights/activations.
        infini_device: infinicore device to run on.
        backend: "python" or "cpp" model implementation.
        tp: Total rank for tensor parallelism. When ``None``, falls back to
            the module-level ``args.tp`` (the original behaviour, kept for
            backward compatibility with the script entry point).

    Raises:
        ValueError: If the loaded model is not a Llama-family model.
    """
    model_path = os.path.expanduser(model_path)

    # BUGFIX: the original read the global ``args`` directly inside the
    # function body, which broke any caller that imported ``test`` without
    # running the script. Prefer the explicit parameter, fall back to the
    # global only when the caller did not supply one.
    if tp is None:
        tp = args.tp

    # ---------------------------------------------------------------------------- #
    #                        Build the model
    # ---------------------------------------------------------------------------- #
    model = infinilm.AutoLlamaModel.from_pretrained(
        model_path,
        device=infini_device,
        dtype=infini_dtype,
        backend=backend,
        distributed_config=DistConfig(tp),
    )

    # ---------------------------------------------------------------------------- #
    #                        Load the weights
    # ---------------------------------------------------------------------------- #
    model_param_infini = get_model_state_dict(
        model_path,
        device=infini_device,
        dtype=infini_dtype,
    )

    model.load_state_dict(model_param_infini)

    # ---------------------------------------------------------------------------- #
    #                        Build the tokenizer
    # ---------------------------------------------------------------------------- #
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    if "llama" == model.config.model_type:
        # BUGFIX: the original rebound the ``backend`` parameter here,
        # shadowing the model-backend argument; use a dedicated local name.
        tok_backend = getattr(tokenizer, "backend_tokenizer", None)
        target = getattr(tok_backend, "_tokenizer", tok_backend)
        norm = getattr(target, "normalizer", None)
        dec = getattr(target, "decoder", None)
        # Inspect reprs to detect SentencePiece-style tokenizers that prepend
        # "▁" in the normalizer and strip it in the decoder; that combination
        # drops leading spaces during incremental (streamed) decoding, so we
        # swap in a decoder that keeps them.
        sn = repr(norm)[:800] if norm is not None else ""
        sd = repr(dec)[:800] if dec is not None else ""
        has_prepend = "Prepend" in sn
        has_strip = "Strip" in sd
        if has_prepend and has_strip:
            target.decoder = _dec.Sequence(
                [
                    _dec.Replace("▁", " "),
                    _dec.ByteFallback(),
                    _dec.Fuse(),
                ]
            )
    else:
        raise ValueError(f"Unsupported model type: {model.config.model_type}")

    # ---------------------------------------------------------------------------- #
    #                        Tokenize the prompts
    # ---------------------------------------------------------------------------- #
    if isinstance(prompts, str):
        prompts = [prompts]
    input_contents = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": prompt}],
            add_generation_prompt=True,
            tokenize=False,
        )
        for prompt in prompts
    ]
    print(input_contents[0], end="", flush=True)
    input_ids_list = tokenizer.batch_encode_plus(input_contents)[
        "input_ids"
    ]  # List: [[1, 1128, 526, 366, 29892]]

    # ---------------------------------------------------------------------------- #
    #                        Autoregressive generation
    # ---------------------------------------------------------------------------- #
    input_ids_infini = infinicore.from_list(input_ids_list)

    t1 = time.time()
    print("=================== start generate ====================")
    model.generate(
        input_ids_infini,
        max_new_tokens=max_new_tokens,
        device=infini_device,
        tokenizer=tokenizer,
    )
    t2 = time.time()

    print(
        f"total_time: {round((t2 - t1) * 1000, 2)} ms",
    )


if __name__ == "__main__":
    args = get_args()
    print(args)

    # Resolve the hardware flag into an infinicore device string. The table
    # is scanned in order, so the first flag set wins — same precedence as
    # an if/elif chain.
    _FLAG_TO_DEVICE = (
        ("cpu", "cpu"),
        ("nvidia", "cuda"),
        ("metax", "cuda"),
        ("moore", "musa"),
        ("iluvatar", "cuda"),
    )
    for flag_name, mapped_device in _FLAG_TO_DEVICE:
        if getattr(args, flag_name):
            device_str = mapped_device
            break
    else:
        # No hardware flag given: show usage and bail out.
        print(
            "Usage:  python examples/llama.py [--cpu | --nvidia | --metax | --moore | --iluvatar] --model_path=<path/to/model_dir>\n"
            "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
        )
        sys.exit(1)

    # One copy of the prompt per batch entry.
    prompts = [args.prompt] * args.batch_size

    infini_device = infinicore.device(device_str, 0)

    # Translate the dtype name into the corresponding infinicore dtype.
    _NAME_TO_DTYPE = {
        "float32": infinicore.float32,
        "bfloat16": infinicore.bfloat16,
        "float16": infinicore.float16,
    }
    if args.dtype not in _NAME_TO_DTYPE:
        raise ValueError(f"Unsupported dtype: {args.dtype}")
    infini_dtype = _NAME_TO_DTYPE[args.dtype]

    test(
        prompts,
        args.model_path,
        args.max_new_tokens,
        infini_device=infini_device,
        infini_dtype=infini_dtype,
        backend=args.backend,
    )