# txtgen.py (6.44 KB), committed by Attila Dusnoki
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################

from argparse import ArgumentParser
from transformers import LlamaTokenizer
import numpy as np
import migraphx as mgx
import os
import sys
import time
from functools import wraps


# measurement helper
def measure(fn):
    """Decorate *fn* so that every call prints its wall-clock duration.

    The wrapper forwards all positional and keyword arguments to *fn*,
    prints the elapsed time in milliseconds (followed by an empty line) and
    returns whatever *fn* returned.
    """
    @wraps(fn)
    def timed_call(*args, **kwargs):
        tick = time.perf_counter_ns()
        outcome = fn(*args, **kwargs)
        elapsed_ns = time.perf_counter_ns() - tick
        # perf_counter_ns reports nanoseconds; scale to milliseconds
        print(f"Elapsed time for {fn.__name__}: {elapsed_ns * 1e-6:.4f} ms\n")
        return outcome

    return timed_call


def get_args(argv=None):
    """Parse the command-line options of the text-generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        The argparse namespace with ``prompt`` (str, required),
        ``log_process`` (bool, default ``False``) and ``max_seq_len``
        (int, one of 256/512/1024/2048/4096, default 1024).
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Input prompt",
    )

    parser.add_argument(
        "-l",
        "--log-process",
        action="store_true",
        help="Print the current state of the generated text.",
    )

    parser.add_argument("-s",
                        "--max-seq-len",
                        type=int,
                        choices=[256, 512, 1024, 2048, 4096],
                        default=1024,
                        help="Max sequence length the model can handle")

    return parser.parse_args(argv)


class Llama2MGX():
    """Greedy text generation with a Llama-2 chat model executed by MIGraphX.

    The model is compiled for one fixed sequence length (``max_seq_len``);
    prompts are padded up to that length and the full padded sequence is
    re-run for every generated token.
    """
    def __init__(self, max_seq_len=1024):
        """Load the MIGraphX model and the matching tokenizer.

        Args:
            max_seq_len: Fixed sequence length used for all three model
                inputs (``input_ids``, ``attention_mask``, ``position_ids``).
        """
        model_id = "meta-llama/Llama-2-7b-chat-hf"
        self.max_seq_len = max_seq_len

        print("Load mgx model")
        self.model = Llama2MGX.load_mgx_model(
            max_seq_len, {
                "input_ids": [1, max_seq_len],
                "attention_mask": [1, max_seq_len],
                "position_ids": [1, max_seq_len]
            })
        print(f"Load AutoTokenizer model from {model_id}")
        self.tokenizer = LlamaTokenizer.from_pretrained(model_id)

    @staticmethod
    @measure
    def load_mgx_model(max_seq_len, shapes):
        """Return a compiled MIGraphX program for the given sequence length.

        A previously saved ``model-<max_seq_len>.mxr`` is preferred. Otherwise
        ``model.onnx`` is parsed with *shapes* as input dimensions, compiled
        for the "gpu" target and saved as ``.mxr`` for the next run.

        Args:
            max_seq_len: Sequence length; selects the ``.mxr`` file name.
            shapes: Mapping of input name to dimensions, passed to
                ``mgx.parse_onnx`` as ``map_input_dims``.

        Raises:
            SystemExit: With status 1 when neither model file exists.
        """
        file = "models/llama-2-7b-chat-hf/model"
        mxr_path = f"{file}-{max_seq_len}.mxr"
        onnx_path = f"{file}.onnx"
        print(f"Loading {max_seq_len} seq-len version model from {file}")
        if os.path.isfile(mxr_path):
            print("Found mxr, loading it...")
            model = mgx.load(mxr_path, format="msgpack")
        elif os.path.isfile(onnx_path):
            print("Parsing from onnx file...")
            model = mgx.parse_onnx(onnx_path, map_input_dims=shapes)
            model.compile(mgx.get_target("gpu"))
            print("Saving model to mxr file...")
            mgx.save(model, mxr_path, format="msgpack")
        else:
            print("No model found. Please download it and re-try.")
            sys.exit(1)
        return model

    @measure
    def tokenize(self, prompt):
        """Return the tokenizer's ``input_ids`` for *prompt* as a numpy array."""
        return self.tokenizer(prompt, return_tensors="np").input_ids

    @measure
    def get_input_features_for_input_ids(self, input_ids):
        """Pad *input_ids* to ``max_seq_len`` and build the other model inputs.

        Args:
            input_ids: 2-D array-like whose first row holds the prompt tokens.

        Returns:
            Tuple ``(input_ids, attention_mask, position_ids)``; each is an
            int64 array of shape ``(1, max_seq_len)``.

        Raises:
            ValueError: If the prompt has more than ``max_seq_len`` tokens.
        """
        input_ids_len = len(input_ids[0])
        padding_len = self.max_seq_len - input_ids_len
        if padding_len < 0:
            # Without this check np.zeros fails with an opaque
            # "negative dimensions" ValueError.
            raise ValueError(
                f"Prompt has {input_ids_len} tokens but the model supports "
                f"at most {self.max_seq_len}")
        input_ids = np.hstack(
            [input_ids,
             np.zeros((1, padding_len), dtype=np.int64)]).astype(np.int64)
        # 0 masked | 1 un-masked
        attention_mask = np.array([1] * input_ids_len +
                                  [0] * padding_len).astype(np.int64)
        attention_mask = attention_mask[np.newaxis]
        position_ids = np.arange(0, self.max_seq_len, dtype=np.int64)
        position_ids = position_ids[np.newaxis]

        return (input_ids, attention_mask, position_ids)

    @measure
    def decode_step(self, input_ids, attention_mask, position_ids):
        """Run the model once and return its first output as a numpy array."""
        return np.array(
            self.model.run({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids
            })[0])

    @measure
    def decode_tokens(self, generated_tokens, skip_special_tokens=True):
        """Convert token ids back to text using the tokenizer."""
        return ''.join(
            self.tokenizer.decode(generated_tokens,
                                  skip_special_tokens=skip_special_tokens))

    @measure
    def generate(self, input_ids, log_process=False):
        """Greedily extend the prompt until EOS or the sequence is full.

        Args:
            input_ids: Prompt tokens as returned by ``tokenize``.
            log_process: When true, print the decoded text after every
                generated token.

        Returns:
            The decoded text of the prompt plus all generated tokens.

        Raises:
            ValueError: If the prompt is longer than ``max_seq_len``.
        """
        prompt_len = len(input_ids[0])
        # number of leading slots of the padded buffer that hold real tokens
        token_count = prompt_len

        input_ids, attention_mask, position_ids = self.get_input_features_for_input_ids(
            input_ids)
        print("Generating response...")
        # The token predicted at `timestep` is stored at `timestep + 1`, so
        # the last usable timestep is max_seq_len - 2; going one further
        # would index past the end of the padded buffer.
        for timestep in range(prompt_len - 1, self.max_seq_len - 1):
            # get logits for the current timestep
            # NOTE(review): indexing implies (batch, seq, vocab) logits - confirm
            logits = self.decode_step(input_ids, attention_mask, position_ids)
            # greedily get the highest probable token
            new_token = np.argmax(logits[0][timestep])

            # add it to the tokens and unmask it
            input_ids[0][timestep + 1] = new_token
            attention_mask[0][timestep + 1] = 1
            token_count = timestep + 2

            if log_process:
                print(self.decode_tokens(input_ids[0][:token_count]))

            if new_token == self.tokenizer.eos_token_id:
                break

        return self.decode_tokens(input_ids[0][:token_count])


if __name__ == "__main__":
    # Read the CLI options and build the model wrapper for the requested
    # sequence length.
    cli = get_args()
    generator = Llama2MGX(cli.max_seq_len)

    print(f"Call tokenizer with \"{cli.prompt}\"")

    # Tokenize the prompt, then let the model extend it.
    prompt_ids = generator.tokenize(cli.prompt)
    answer = generator.generate(prompt_ids, log_process=cli.log_process)

    print(f"Result text: {answer}")