turbomind.py 14.6 KB
Newer Older
q.yao's avatar
q.yao committed
1
# Copyright (c) OpenMMLab. All rights reserved.
AllentDan's avatar
AllentDan committed
2
import asyncio
q.yao's avatar
q.yao committed
3
import os.path as osp
4
import sys
q.yao's avatar
q.yao committed
5
6
from configparser import ConfigParser
from contextlib import contextmanager
q.yao's avatar
q.yao committed
7
8
from queue import Queue
from threading import Thread
9
10
from typing import Iterable, List

q.yao's avatar
q.yao committed
11
import numpy as np
12
import torch
q.yao's avatar
q.yao committed
13
14
from torch.nn.utils.rnn import pad_sequence

15
import lmdeploy
16
from lmdeploy.model import MODELS
17
from lmdeploy.turbomind import Tokenizer
18
from lmdeploy.utils import get_logger
19

q.yao's avatar
q.yao committed
20
21
22
# TODO: find another way import _turbomind
# The compiled `_turbomind` extension lives in the `lib` directory next to
# the installed `lmdeploy` package; put that directory on sys.path so it can
# be imported as a top-level module.
lmdeploy_dir = osp.dirname(lmdeploy.__file__)
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
import _turbomind as _tm  # noqa: E402
q.yao's avatar
q.yao committed
24
25


26
def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
    """Encode stop words into the ``stop_words_list`` array fed to turbomind.

    Args:
        stop_words (List[str] | None): stop words of a chat template.
        tokenizer (Tokenizer): tokenizer used to encode each stop word.

    Returns:
        numpy.ndarray | None: ``None`` when ``stop_words`` is ``None``;
        otherwise an int32 array of shape ``(1, 2, N)`` whose row 0 holds one
        token id per stop word and whose row 1 holds the offsets ``1..N``.

    Raises:
        AssertionError: if ``stop_words`` is not a list of str, or if the
        tokenizer does not produce int token ids.
    """
    if stop_words is None:
        return None
    assert isinstance(stop_words, list), \
        f'stop_words must be a list but got {type(stop_words)}'
    assert all(isinstance(elem, str) for elem in stop_words), \
        f'stop_words must be a list of str but got {stop_words}'

    # only the LAST token id of each encoded stop word is kept
    stop_ids = [tokenizer.encode(stop_word)[-1] for stop_word in stop_words]
    assert all(isinstance(elem, int) for elem in stop_ids), 'invalid stop_words'

    # each id in stop_words represents a stop word
    # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
    # detailed explanation about fastertransformer's stop_words
    stop_word_offsets = range(1, len(stop_ids) + 1)
    return np.array([[stop_ids, stop_word_offsets]]).astype(np.int32)


def _np_dict_to_tm_dict(np_dict: dict):
    """Pack a ``dict`` of arrays into a turbomind ``TensorMap``.

    Every value is wrapped with ``_tm.from_dlpack`` under the same key. The
    callers in this module pass numpy arrays and torch tensors as values
    (presumably anything implementing DLPack works — TODO confirm).
    """
    tensor_map = _tm.TensorMap()
    for name, array in np_dict.items():
        tensor_map[name] = _tm.from_dlpack(array)
    return tensor_map


def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
    """Convert a turbomind ``TensorMap`` into a ``dict`` of torch tensors.

    Each tensor goes through ``torch.from_dlpack`` under its original key.
    ``TYPE_UINT32`` tensors are first re-viewed as ``TYPE_INT32``
    (presumably because torch lacks full uint32 support — TODO confirm).
    """

    def _to_torch(tm_tensor):
        if tm_tensor.type == _tm.DataType.TYPE_UINT32:
            tm_tensor = tm_tensor.view(_tm.DataType.TYPE_INT32)
        return torch.from_dlpack(tm_tensor)

    return {name: _to_torch(tm_tensor) for name, tm_tensor in tm_dict.items()}


q.yao's avatar
q.yao committed
64
65
66
67
68
69
70
71
@contextmanager
def cuda_ctx(device_id):
    """Temporarily make ``device_id`` the current CUDA device.

    The device that was current on entry is restored on exit, including when
    the body of the ``with`` statement raises.

    Args:
        device_id (int): index of the CUDA device to switch to.
    """
    old_device = torch.cuda.current_device()
    torch.cuda.set_device(device_id)
    try:
        yield
    finally:
        # without try/finally an exception raised at the yield would skip
        # the restore and leave device_id current
        torch.cuda.set_device(old_device)


q.yao's avatar
q.yao committed
72
class TurboMind:
    """LMDeploy's inference engine.

    Reads ``<model_path>/triton_models/weights/config.ini``, loads the
    tokenizer from ``<model_path>/triton_models/tokenizer`` and creates the
    shared weights on every GPU (one thread per device).

    Args:
        model_path (str): the path of turbomind's model
        eos_id (int): eos token id
        tp (int): tensor parallel. Replaced by ``tensor_para_size`` from
            config.ini when that value is neither 1 nor equal to ``tp``.

    Raises:
        ValueError: if config.ini has neither a ``[turbomind]`` nor a
            ``[llama]`` section.
    """

    def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
        self.eos_id = eos_id

        # TODO: support mpi
        node_id = 0
        node_num = 1

        # read meta from model path
        self.gpu_count = tp
        ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
        with open(ini_path, 'r') as f:
            parser = ConfigParser()
            parser.read_file(f)

        # the model settings may live in either a [turbomind] or a [llama]
        # section; [turbomind] takes precedence when both exist
        section_name = next(
            (name for name in ('turbomind', 'llama') if name in parser), None)
        if section_name is None:
            raise ValueError(f'{ini_path} has neither a [turbomind] nor a '
                             '[llama] section')

        tp_cfg = parser.getint(section_name, 'tensor_para_size')
        self.session_len = parser.getint(section_name, 'session_len')
        if tp_cfg != 1 and tp_cfg != tp:
            get_logger('turbomind').info(f'found tp={tp_cfg} in config.ini.')
            self.gpu_count = tp_cfg
        self.model_name = parser.get(section_name, 'model_name')
        data_type = parser.get(section_name, 'weight_type')

        # the registered chat template provides the stop words
        chat_template = MODELS.get(self.model_name)()
        tokenizer_model_path = osp.join(model_path, 'triton_models',
                                        'tokenizer')
        tokenizer = Tokenizer(tokenizer_model_path)
        self.stop_words = _stop_words(chat_template.stop_words, tokenizer)

        # params
        self.node_id = node_id
        self.node_num = node_num
        self.world_size = self.node_num * self.gpu_count

        # create model
        weight_dir = osp.join(model_path, 'triton_models', 'weights')
        model = _tm.AbstractTransformerModel.create_llama_model(
            weight_dir, tensor_para_size=self.gpu_count, data_type=data_type)
        self.model = model
        self.nccl_params = model.create_nccl_params(self.node_id)
        torch.cuda.synchronize()

        # create weight: one thread per local GPU, each with its own rank
        def _create_weight(device_id):
            with cuda_ctx(device_id):
                rank = self.node_id * self.gpu_count + device_id
                model.create_shared_weights(device_id, rank)

        threads = []
        for device_id in range(self.gpu_count):
            t = Thread(target=_create_weight, args=(device_id, ))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()

    def create_instance(self, cuda_stream_id=0):
        """Create a turbomind instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream
        Returns:
            TurboMindInstance: an instance of turbomind
        """
        return TurboMindInstance(self, cuda_stream_id)
q.yao's avatar
q.yao committed
153
154
155


class TurboMindInstance:
    """Instance of TurboMind.

    Args:
        tm_model (TurboMind): the engine this instance runs on; its model,
            nccl params, stop words, eos id and session length are reused
        cuda_stream_id(int): identity of a cuda stream
    """

    def __init__(self, tm_model, cuda_stream_id=0):
        self.tm_model = tm_model
        self.cuda_stream_id = cuda_stream_id

        self.node_id = tm_model.node_id
        self.gpu_count = tm_model.gpu_count

        self.stop_words = tm_model.stop_words
        # token ids stripped from the tail of generated outputs
        self.stop_tokens = self._stop_token_ids(self.stop_words)
        self.eos_id = tm_model.eos_id
        self.session_len = tm_model.session_len

        self.nccl_params = tm_model.nccl_params
        self.instance_comm = tm_model.model.create_instance_comm(
            self.gpu_count)

        # create model instances, one per GPU, in parallel
        model_insts = [None] * self.gpu_count
        threads = []
        for device_id in range(self.gpu_count):
            t = Thread(target=self._create_model_instance,
                       args=(device_id, model_insts))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()

        self.model_insts = model_insts
        # receives (finished, outputs) pairs: (False, ...) from the streaming
        # callback and (True, ...) from the device-0 forward thread
        self.que = Queue()
        self.threads = [None] * self.gpu_count

    @staticmethod
    def _stop_token_ids(stop_words):
        """Return the stop token ids held in a stop-words array as a list.

        ``stop_words`` is ``None`` or the ``(1, 2, N)`` array produced by
        ``_stop_words``: row 0 holds the token ids, row 1 the offsets
        ``1..N``. Only row 0 contains token ids.
        """
        if stop_words is None:
            return []
        return stop_words[0, 0].tolist()

    @staticmethod
    def _normalize_batch(input_ids):
        """Return ``input_ids`` as a batch (a sequence of id sequences).

        An empty input becomes a batch holding one empty sequence and a flat
        list of ints is wrapped into a batch of one.
        """
        if len(input_ids) == 0:
            return [[]]
        if isinstance(input_ids[0], int):
            return [input_ids]
        return input_ids

    def _pad_input_ids(self, input_ids):
        """Right-pad a batch with ``eos_id``; return ``(ids, lengths)``.

        Both results are int32 torch tensors: ``ids`` is
        ``(batch, max_len)`` and ``lengths`` holds each unpadded length.
        """
        tensors = [torch.IntTensor(ids) for ids in input_ids]
        input_lengths = torch.IntTensor([len(ids) for ids in tensors])
        padded = pad_sequence(tensors,
                              batch_first=True,
                              padding_value=self.eos_id)
        return padded, input_lengths

    @staticmethod
    def _broadcast_np(data, dtype, batch_size):
        """Return ``data`` as a ``(batch_size, )`` numpy array of ``dtype``.

        Scalars are repeated; iterables must already have ``batch_size``
        elements and are converted (so plain lists work as well).
        """
        if isinstance(data, Iterable):
            assert len(data) == batch_size
            return np.asarray(data, dtype=dtype)
        return np.full((batch_size, ), data, dtype=dtype)

    def _create_model_instance(self, device_id, model_insts):
        """Create the model instance of ``device_id`` into ``model_insts``."""
        with cuda_ctx(device_id):
            rank = self.node_id * self.gpu_count + device_id
            model_inst = self.tm_model.model.create_model_instance(
                device_id, rank, self.cuda_stream_id, self.nccl_params)
            model_insts[device_id] = model_inst

    def _forward_callback(self, result, ctx):
        """Streaming callback: enqueue an intermediate (unfinished) result."""
        self.que.put((False, result))

    def _forward_thread(self, inputs):
        """Start one forward thread per GPU.

        Only the device-0 thread enqueues its output, as ``(True, output)``.
        The threads are stored in ``self.threads`` for the caller to join.
        """

        def _func(device_id, enque_output):
            with cuda_ctx(device_id):
                output = self.model_insts[device_id].forward(
                    inputs, self.instance_comm)
                if enque_output:
                    self.que.put((True, output))

        for device_id in range(self.gpu_count):
            t = Thread(target=_func, args=(device_id, device_id == 0))
            t.start()
            self.threads[device_id] = t

    async def async_stream_infer(self, *args, **kwargs):
        """Async wrapper of self.stream_infer."""
        for output in self.stream_infer(*args, **kwargs):
            # Allow the pipeline add new requests into the queue.
            await asyncio.sleep(0)
            yield output

    def stream_infer(self,
                     session_id,
                     input_ids,
                     request_output_len: int = 512,
                     sequence_start: bool = True,
                     sequence_end: bool = False,
                     step=1,
                     stop=False,
                     top_p=0.8,
                     top_k=40,
                     temperature=0.8,
                     repetition_penalty=1.0,
                     ignore_eos=False,
                     random_seed=None,
                     stream_output=False):
        """Perform model inference.

        Args:
            session_id (int): the id of a session
            input_ids (numpy.ndarray): the token ids of a prompt
            request_output_len (int): the max number of to-be-generated tokens
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            stop (bool): indicator for cancelling the session
            top_p (float): If set to float < 1, only the smallest set of most
                probable tokens with probabilities that add up to top_p or
                higher are kept for generation.
            top_k (int): The number of the highest probability vocabulary
                tokens to keep for top-k-filtering
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            random_seed (int): seed used by sampling
            stream_output (bool): indicator for stream output

        Yields:
            list[tuple[torch.Tensor, int]]: per sequence, the output token
            ids starting at ``input_length + step`` and their count. A
            trailing eos token is stripped and the count reduced by one; a
            trailing stop token is stripped with the count left unchanged.
        """
        if stream_output:
            self.model_insts[0].register_callback(self._forward_callback)

        input_ids = self._normalize_batch(input_ids)
        batch_size = len(input_ids)
        input_ids, input_lengths = self._pad_input_ids(input_ids)

        def _broadcast(data, dtype):
            return self._broadcast_np(data, dtype, batch_size)

        if isinstance(session_id, int):
            session_id = [session_id]
        assert len(session_id) == batch_size

        step = _broadcast(step, np.int32)

        inputs = dict(
            input_ids=input_ids,
            input_lengths=input_lengths,
            request_output_len=np.full(input_lengths.shape,
                                       request_output_len,
                                       dtype=np.uint32),
            runtime_top_k=_broadcast(top_k, np.uint32),
            runtime_top_p=_broadcast(top_p, np.float32),
            temperature=_broadcast(temperature, np.float32),
            repetition_penalty=_broadcast(repetition_penalty, np.float32),
            step=step,

            # session input
            session_len=_broadcast(self.session_len, np.uint32),
            START=_broadcast((1 if sequence_start else 0), np.int32),
            END=_broadcast((1 if sequence_end else 0), np.int32),
            CORRID=np.array(session_id, dtype=np.uint64),
            STOP=_broadcast((1 if stop else 0), np.int32))

        if ignore_eos:
            # no stop words, and eos is passed as a bad word instead
            stop_words = None
            bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
        else:
            stop_words = self.stop_words
            bad_words = None

        if stop_words is not None:
            inputs['stop_words_list'] = stop_words
        if bad_words is not None:
            inputs['bad_words_list'] = bad_words

        if random_seed is not None:
            inputs['random_seed'] = _broadcast(random_seed, np.uint64)
        tm_inputs = _np_dict_to_tm_dict(inputs)

        # start forward thread
        self._forward_thread(tm_inputs)

        seq_start = input_lengths + input_lengths.new_tensor(step)

        # generator
        while True:
            # keep only the newest queued result; older ones are stale
            while self.que.qsize() > 1:
                self.que.get()

            finish, tm_outputs = self.que.get()

            outputs = _tm_dict_to_torch_dict(tm_outputs)

            output_ids = outputs['output_ids'][:, 0, :]
            sequence_length = outputs['sequence_length'].long()[:, 0].cpu()
            output_ids = [
                output_id[s:l] for output_id, s, l in zip(
                    output_ids, seq_start, sequence_length)
            ]
            sequence_length -= seq_start.to(sequence_length.device)

            outputs = []
            for output, len_ in zip(output_ids, sequence_length):
                len_ = len_.item()
                if len(output) > 0 and output[-1].item() == self.eos_id:
                    outputs.append((output[:-1], len_ - 1))
                elif len(output) > 0 and output[-1].item() in self.stop_tokens:
                    outputs.append((output[:-1], len_))
                else:
                    outputs.append((output, len_))

            yield outputs

            if finish:
                for t in self.threads:
                    t.join()
                while self.que.qsize() > 0:
                    self.que.get()
                break

        if stream_output:
            self.model_insts[0].unregister_callback()

    def decode(self, input_ids):
        """Perform context decode on input tokens.

        Args:
            input_ids (numpy.ndarray): the batch of input token ids

        Returns:
            torch.Tensor: the ``logits`` output with the position of the
            extra appended token removed (``[:, :-1, :]``).
        """
        input_ids = self._normalize_batch(input_ids)
        # append an extra token since input_len-1 tokens will be
        # decoded by context decoder; build new lists so the caller's
        # sequences are not mutated
        input_ids = [list(ids) + [0] for ids in input_ids]
        batch_size = len(input_ids)
        input_ids, input_lengths = self._pad_input_ids(input_ids)

        inputs = dict(
            input_ids=input_ids,
            input_lengths=input_lengths,
            request_output_len=self._broadcast_np(0, np.uint32, batch_size),
            is_return_logits=self._broadcast_np(1, np.uint32, batch_size))

        tm_inputs = _np_dict_to_tm_dict(inputs)

        # start forward thread
        self._forward_thread(tm_inputs)

        _, tm_outputs = self.que.get()
        # wait for every device's forward to finish before returning
        for t in self.threads:
            t.join()

        outputs = _tm_dict_to_torch_dict(tm_outputs)
        logits = outputs['logits']

        return logits[:, :-1, :]