# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import sys
from queue import Queue
from threading import Thread
from typing import Iterable, List

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

import lmdeploy

# TODO: find another way to import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm  # noqa: E402


def _stop_words(stop_words: List[int]):
    """return list of stop-words to numpy.ndarray."""
    if stop_words is None:
        return None
    assert isinstance(stop_words, List) and \
           all(isinstance(elem, int) for elem in stop_words), \
           f'stop_words must be a list of ints but got {type(stop_words)}'

    # each id in stop_words represents a stop word
    # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
    # a detailed explanation of fastertransformer's stop_words format
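    # e.g. stop_words=[13, 29871] (illustrative token ids) becomes an
    # int32 array of shape [1, 2, 2]:
    #   [[[13, 29871],   # flattened stop-word token ids
    #     [1, 2]]]       # exclusive end offsets into the row above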
    stop_word_offsets = range(1, len(stop_words) + 1)
    stop_words = np.array([[stop_words, stop_word_offsets]]).astype(np.int32)
    return stop_words


def _np_dict_to_tm_dict(np_dict: dict):
    """map numpy.ndarray to turbomind's tensor."""
    ret = _tm.TensorMap()
    for k, v in np_dict.items():
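        # hand each array to turbomind via the DLPack protocol (zero-copy)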
        ret[k] = _tm.from_dlpack(v)

    return ret


def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
    """map turbomind's tensor to torch's tensor."""
    ret = dict()
    for k, v in tm_dict.items():
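        # torch has no uint32 dtype, so reinterpret TYPE_UINT32 tensors
        # as int32 before exporting them through DLPack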
        if v.type == _tm.DataType.TYPE_UINT32:
            v = v.view(_tm.DataType.TYPE_INT32)
        ret[k] = torch.from_dlpack(v)

    return ret


class TurboMind:
    """LMDeploy's inference engine.

    Args:
        model_path (str): the path of turbomind's model
        data_type (str): the data type of the model, e.g. 'fp16'
        session_len (int): the max length of a session
        eos_id (int): eos token id
        stop_words (List[int]): token ids of stop-words
        device_id (int): the id of a gpu card
        node_id (int): the id of a node
        device_num (int): the number of gpu cards
        node_num (int): the number of nodes
    """

    def __init__(self,
                 model_path: str,
                 data_type: str = 'fp16',
                 session_len: int = 2048,
                 eos_id: int = 2,
                 stop_words: List[int] = None,
                 device_id: int = 0,
                 node_id: int = 0,
                 device_num: int = 1,
                 node_num: int = 1):
        self.eos_id = eos_id

        # create model instance
        self.node_id = node_id
        self.node_num = node_num
        self.gpu_count = device_num
        self.device_id = device_id
        self.world_size = self.node_num * self.gpu_count
        self.rank = self.node_id * self.gpu_count + self.device_id
        self.session_len = session_len

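        # a converted turbomind workspace keeps its weights under
        # <model_path>/triton_models/weights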
        weight_dir = osp.join(model_path, 'triton_models', 'weights')
        model = _tm.AbstractTransformerModel.create_llama_model(
            weight_dir, tensor_para_size=self.gpu_count, data_type=data_type)
        model.create_shared_weights(self.device_id, self.rank)
        self.model = model
        self.stop_words = _stop_words(stop_words)

    def create_instance(self, cuda_stream_id=0):
        """Create a turbomind instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream
        Returns:
            TurboMindInstance: an instance of turbomind
        """
        return TurboMindInstance(self, cuda_stream_id)


class TurboMindInstance:
    """Instance of TurboMind.

    Args:
        tm_model (TurboMind): the TurboMind engine that created this instance
        cuda_stream_id(int): identity of a cuda stream
    """

    def __init__(self, tm_model, cuda_stream_id=0):
        self.tm_model = tm_model

        self.device_id = tm_model.device_id
        self.rank = tm_model.rank
        self.stop_words = tm_model.stop_words
        self.eos_id = tm_model.eos_id
        self.session_len = tm_model.session_len
        self.cuda_stream_id = cuda_stream_id

        # create instance
        model = tm_model.model
        nccl_params = model.create_nccl_params(tm_model.node_id)
        custom_comms = model.create_custom_comms(tm_model.world_size)
        instance_comm = model.create_instance_comm(tm_model.gpu_count)

        model_inst = model.create_model_instance(self.device_id, self.rank,
                                                 self.cuda_stream_id,
                                                 nccl_params, custom_comms[0])
        # model_inst.register_callback(self._forward_callback)
        self.model_inst = model_inst
        self.instance_comm = instance_comm
        self.que = Queue()
        self.thread = None

    def _forward_callback(self, result, ctx):
        self.que.put((False, result))

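    # forward() blocks until generation completes, so it runs on a worker
    # thread; with stream_output, intermediate results arrive through the
    # queue via _forward_callback and the final result via _func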
    def _forward_thread(self, inputs):

        def _func():
            output = self.model_inst.forward(inputs, self.instance_comm)
            self.que.put((True, output))

        self.thread = Thread(target=_func)
        self.thread.start()

    def stream_infer(self,
                     session_id,
                     input_ids,
                     request_output_len: int = 512,
                     sequence_start: bool = True,
                     sequence_end: bool = False,
                     step=1,
                     stop=False,
                     top_p=0.8,
                     top_k=40,
                     temperature=0.8,
                     repetition_penalty=1.05,
                     ignore_eos=False,
                     random_seed=None,
                     stream_output=False):
        """Perform model inference.

        Args:
            session_id (int): the id of a session
            input_ids (numpy.ndarray): the token ids of a prompt
            request_output_len (int): the max number of to-be-generated tokens
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            stop (bool): indicator for cancelling the session
            top_p (float): If set to a float < 1, only the smallest set
              of most probable tokens with probabilities that add up to
              top_p or higher are kept for generation.
            top_k (int): The number of the highest probability vocabulary
              tokens to keep for top-k-filtering
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
              1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            random_seed (int): seed used by sampling
            stream_output (bool): indicator for stream output
        """
        if stream_output:
            self.model_inst.register_callback(self._forward_callback)

        # normalize input_ids to a batch: a list of token-id lists
        if len(input_ids) == 0:
            input_ids = [[]]
        if isinstance(input_ids[0], int):
            input_ids = [input_ids]

        batch_size = len(input_ids)

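        # broadcast a scalar sampling parameter to one value per sequence;
        # an iterable is assumed to already be batch-sized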
        def _broadcast_np(data, dtype, shape=(batch_size, )):
            if isinstance(data, Iterable):
                assert len(data) == batch_size
                return data

            return np.full(shape, data, dtype=dtype)

        input_ids = [torch.IntTensor(ids) for ids in input_ids]
        input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
        input_ids = pad_sequence(input_ids,
                                 batch_first=True,
                                 padding_value=self.eos_id)

        if isinstance(session_id, int):
            session_id = [session_id]
        assert len(session_id) == batch_size

        step = _broadcast_np(step, np.int32)

        inputs = dict(
            input_ids=input_ids,
            input_lengths=input_lengths,
            request_output_len=np.full(input_lengths.shape,
                                       request_output_len,
                                       dtype=np.uint32),
            runtime_top_k=_broadcast_np(top_k, np.uint32),
            runtime_top_p=_broadcast_np(top_p, np.float32),
            temperature=_broadcast_np(temperature, np.float32),
            repetition_penalty=_broadcast_np(repetition_penalty, np.float32),
            step=step,

            # session control tensors: START/END mark sequence start/end,
            # CORRID carries the session id and STOP cancels the session
            session_len=self.session_len *
            np.ones([batch_size], dtype=np.uint32),
            START=_broadcast_np((1 if sequence_start else 0), np.int32),
            END=_broadcast_np((1 if sequence_end else 0), np.int32),
            CORRID=np.array(session_id, dtype=np.uint64),
            STOP=_broadcast_np((1 if stop else 0), np.int32))

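        # to ignore eos, feed eos_id as a "bad word" (same [1, 2, n]
        # layout as stop words) so sampling can never emit it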
        if ignore_eos:
            stop_words = None
            bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
        else:
            stop_words = self.stop_words
            bad_words = None

        if stop_words is not None:
            inputs['stop_words_list'] = stop_words
        if bad_words is not None:
            inputs['bad_words_list'] = bad_words

        if random_seed is not None:
            inputs['random_seed'] = _broadcast_np(random_seed, np.uint64)
        tm_inputs = _np_dict_to_tm_dict(inputs)

        # start forward thread
        self._forward_thread(tm_inputs)

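        # index of the first generated token per sequence: prompt length
        # plus the k/v-cache offset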
        seq_start = input_lengths + input_lengths.new_tensor(step)

        # generator
        while True:
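            # drain stale results; only the most recent output is decoded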
            while self.que.qsize() > 1:
                self.que.get()

            finish, tm_outputs = self.que.get()

            outputs = _tm_dict_to_torch_dict(tm_outputs)

            output_ids = outputs['output_ids'][:, 0, :]
            sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
            output_ids = [
                output_id[s:l] for output_id, s, l in zip(
                    output_ids, seq_start, sequence_length)
            ]
            sequence_length -= seq_start.to(sequence_length.device)
            yield [(output, l.item())
                   for output, l in zip(output_ids, sequence_length)]

            if finish:
                while self.que.qsize() > 0:
                    self.que.get()
                self.thread.join()
                break

        if stream_output:
            self.model_inst.unregister_callback()
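

# Illustrative usage sketch, not part of the module's API surface.
# Assumes './workspace' is a hypothetical directory produced by lmdeploy's
# model converter and that the prompt is already tokenized to ids.
if __name__ == '__main__':
    tm = TurboMind(model_path='./workspace')
    generator = tm.create_instance()
    prompt_ids = [1, 42, 128]  # placeholder token ids
    for outputs in generator.stream_infer(session_id=1,
                                          input_ids=prompt_ids,
                                          request_output_len=64,
                                          stream_output=True):
        # each element is a (token_ids, generated_len) pair per sequence
        token_ids, gen_len = outputs[0]
        print(gen_len, token_ids.tolist())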