".github/workflows/release-whl-kernel-aarch64.yml" did not exist on "88defc4d89b766ce2ed9d0828d31f583b094c278"
turbomind.py 27.3 KB
Newer Older
q.yao's avatar
q.yao committed
1
# Copyright (c) OpenMMLab. All rights reserved.
AllentDan's avatar
AllentDan committed
2
import asyncio
3
4
5
6
import copy
import io
import json
import logging
q.yao's avatar
q.yao committed
7
import os.path as osp
8
import sys
q.yao's avatar
q.yao committed
9
10
from configparser import ConfigParser
from contextlib import contextmanager
q.yao's avatar
q.yao committed
11
12
from queue import Queue
from threading import Thread
13
from typing import Iterable, List, Optional
14

q.yao's avatar
q.yao committed
15
import numpy as np
16
import torch
17
from huggingface_hub import snapshot_download
q.yao's avatar
q.yao committed
18
19
from torch.nn.utils.rnn import pad_sequence

20
import lmdeploy
21
from lmdeploy.model import MODELS, BaseModel
22
from lmdeploy.tokenizer import Tokenizer
23
from lmdeploy.utils import get_logger
24

q.yao's avatar
q.yao committed
25
26
from .deploy.converter import (get_model_format, supported_formats,
                               update_config_weight_type, update_output_format)
27
28
29
30
31
from .deploy.source_model.base import INPUT_MODELS
from .deploy.target_model.base import OUTPUT_MODELS, TurbomindModelConfig
from .utils import (ModelSource, check_tm_model_input, create_hf_download_args,
                    get_hf_config_content, get_model_source)

q.yao's avatar
q.yao committed
32
33
34
# TODO: find another way import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
sys.path.append(osp.join(lmdeploy_dir, 'lib'))
35
import _turbomind as _tm  # noqa: E402
q.yao's avatar
q.yao committed
36

37
38
logger = logging.getLogger(__name__)


def _stop_words(stop_words: List[str], tokenizer: Tokenizer):
    """Convert stop words to a numpy.ndarray of token ids and offsets."""
    if stop_words is None:
        return None
    assert isinstance(stop_words, List) and \
        all(isinstance(elem, str) for elem in stop_words), \
        f'stop_words must be a list but got {type(stop_words)}'
    stop_words = [
        tokenizer.encode(stop_word, False)[-1] for stop_word in stop_words
    ]
    assert isinstance(stop_words, List) and all(
        isinstance(elem, int) for elem in stop_words), 'invalid stop_words'
    # each id in stop_words represents a stop word
    # refer to https://github.com/fauxpilot/fauxpilot/discussions/165 for
    # detailed explanation about fastertransformer's stop_words
    stop_word_offsets = range(1, len(stop_words) + 1)
    stop_words = np.array([[stop_words, stop_word_offsets]]).astype(np.int32)
    return stop_words
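# For illustration only (token ids are made up): two stop words whose last
# token ids are 103 and 104 would produce
#   np.array([[[103, 104], [1, 2]]], dtype=np.int32)
# i.e. a [1, 2, 2] int32 array of flattened ids and cumulative offsets, the
# layout turbomind/fastertransformer expects for stop_words_list.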


def _np_dict_to_tm_dict(np_dict: dict):
    """Map a dict of numpy.ndarray to turbomind's TensorMap."""
    ret = _tm.TensorMap()
    for k, v in np_dict.items():
        ret[k] = _tm.from_dlpack(v)

    return ret


def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
    """Map turbomind's TensorMap to a dict of torch.Tensor."""
    ret = dict()
    for k, v in tm_dict.items():
        if v.type == _tm.DataType.TYPE_UINT32:
            v = v.view(_tm.DataType.TYPE_INT32)
        ret[k] = torch.from_dlpack(v)

    return ret


@contextmanager
def cuda_ctx(device_id):
    """Temporarily switch to the given CUDA device."""
    old_device = torch.cuda.current_device()
    torch.cuda.set_device(device_id)
    try:
        yield
    finally:
        torch.cuda.set_device(old_device)
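# Usage sketch (device id is illustrative): the body runs on the selected GPU
# and the previous device is restored afterwards:
#   with cuda_ctx(1):
#       torch.cuda.synchronize()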


class TurboMind:
    """LMDeploy's inference engine.

    Args:
        model_path (str): the path of turbomind's model
        model_source (ModelSource): the source of the model
        model_name (str): needed when model_path is a hf model and is not
            managed by lmdeploy
        model_format (str): needed when model_path is a hf model and is not
            managed by lmdeploy
        group_size (int): needed when model_path is a hf model and is not
            managed by lmdeploy
        tp (int): tensor parallel size
    """

    def __init__(self,
                 model_path: str,
                 model_source: ModelSource = ModelSource.WORKSPACE,
                 model_name: Optional[str] = None,
                 model_format: Optional[str] = None,
                 group_size: Optional[int] = None,
                 tp: Optional[int] = None,
                 **kwargs):
        if tp is not None:
            assert ((tp & (tp - 1) == 0) and tp != 0), 'tp should be 2^n'
        self.gpu_count = tp if tp is not None else 1

        if model_source == ModelSource.WORKSPACE:
            tokenizer_model_path = osp.join(model_path, 'triton_models',
                                            'tokenizer')
            self.tokenizer = Tokenizer(tokenizer_model_path)
            self.model_comm = self._from_workspace(model_path)
        else:
            self.tokenizer = Tokenizer(model_path)
            self.model_comm = self._from_hf(model_source=model_source,
                                            model_path=model_path,
                                            model_name=model_name,
                                            model_format=model_format,
                                            group_size=group_size,
                                            tp=tp,
                                            **kwargs)

        self.eos_id = self.tokenizer.eos_token_id
        self.model: BaseModel = MODELS.get(self.model_name)(**kwargs)
        self.session_len = self.model.session_len
        self.stop_words = _stop_words(self.model.stop_words, self.tokenizer)

    def _create_weight(self, model_comm):
        """Allocate weight buffer, load params if from_workspace."""

        # TODO: support mpi
        self.node_id = 0
        self.node_num = 1
        self.nccl_params = model_comm.create_nccl_params(self.node_id)
        torch.cuda.synchronize()

        # create weight
        def _create_weight_func(device_id):
            with cuda_ctx(device_id):
                rank = self.node_id * self.gpu_count + device_id
                model_comm.create_shared_weights(device_id, rank)

        threads = []
        for device_id in range(self.gpu_count):
            t = Thread(target=_create_weight_func, args=(device_id, ))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()

    def _load_kv_qparams(self, model_path, tm_params, **kwargs):
        """Load kv qparams when loading from hf."""
        if self.config.quant_policy:
            logger.warning('loading kv_cache quant scale')
            from lmdeploy.lite.apis.kv_qparams import main as kv_loader
            kv_sym = kwargs.get('kv_sym', False)
            kv_bits = kwargs.get('kv_bits', 8)
            tp = self.config.tensor_para_size
            kv_loader(model_path, model_path, kv_bits, kv_sym, tp, tm_params)
        else:
            for key in list(tm_params.keys()):
                if 'past_kv_scale' in key:
                    tm_params.pop(key)

    def _get_model_params(self, model_comm, tm_params):
        """Get turbomind model params when loading from hf."""

        def _get_params(device_id, que):
            with cuda_ctx(device_id):
                rank = self.node_id * self.gpu_count + device_id
                out = model_comm.get_params(device_id, rank)
                que.put(out)

        que = Queue()
        threads = []
        for device_id in range(self.gpu_count):
            t = Thread(target=_get_params, args=(device_id, que))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()

        for _ in range(self.gpu_count):
            tensor_map = que.get()
            for k, v in tensor_map.items():
                if k not in tm_params:
                    tm_params[k] = []
                tm_params[k].append(v)

    def _from_hf(self,
                 model_source: ModelSource,
                 model_path: str,
                 model_name: Optional[str] = None,
                 model_format: Optional[str] = None,
                 group_size: Optional[int] = None,
                 tp: Optional[int] = None,
                 **kwargs):
        """Load model which is in hf format."""
        # get model_name, group_size if is lmdeploy managed.
        if model_source == ModelSource.HF_LMDEPLOY:
            config = get_hf_config_content(model_path, local_files_only=True)
            tm_config = config['turbomind']
            tm_config.update(kwargs)
            var_should_be_none = dict(model_name=model_name,
                                      model_format=model_format,
                                      group_size=group_size)
            for key, value in var_should_be_none.items():
                assert value is None, f'{key} should be None when model is '\
                    f'from {model_source}'
            model_name = tm_config['model_name']
            group_size = tm_config['group_size']
            if tm_config['weight_type'] == 'int4':
                model_format = 'awq'
        else:
            assert model_name is not None, 'please supply model_name when ' \
                f'model is from {model_source}'
            if osp.exists(osp.join(model_path, 'outputs_stats.pth')):
                model_format = 'awq' if model_format is None else model_format
                group_size = 128 if group_size is None else group_size
            tm_config = kwargs

        assert model_name in MODELS.module_dict.keys(), \
            f"'{model_name}' is not supported. " \
            f'The supported models are: {MODELS.module_dict.keys()}'
        assert model_format in supported_formats, 'the model format ' \
            f'should be in {supported_formats}'

        data_type = 'fp16'
        output_format = 'fp16'
        inferred_model_format = get_model_format(model_name, model_format)
        cfg = TurbomindModelConfig.from_dict(tm_config, allow_none=True)

        # overwrite with input params
        cfg.model_name = model_name
        cfg.tensor_para_size = 1 if tp is None else tp
        cfg.rotary_embedding = cfg.size_per_head
        cfg.group_size = group_size
        if inferred_model_format.find('awq') != -1:
            cfg.weight_type = 'int4'
            output_format = 'w4'
            data_type = 'int4'
            assert group_size > 0, f'group_size: {group_size} should be > 0'
q.yao's avatar
q.yao committed
250
251
252
253
254
255
        else:
            output_format = update_output_format(model_name,
                                                 inferred_model_format,
                                                 model_path, output_format)
            data_type = output_format
            update_config_weight_type(output_format, cfg)

        self.config = cfg
        self.model_name = model_name
        self.data_type = data_type

        input_model = INPUT_MODELS.get(inferred_model_format)(
            model_path=model_path, tokenizer_path=model_path, ckpt_path=None)

        output_model = OUTPUT_MODELS.get(output_format)(
            input_model=input_model, cfg=cfg, to_file=False, out_dir='')

        config = copy.deepcopy(output_model.cfg.__dict__)
        logger.warning(f'model_config:\n{json.dumps(config, indent=2)}')
        parser = ConfigParser()
        parser['llama'] = config
        with io.StringIO() as ss:
            parser.write(ss)
            ss.seek(0)
            config = ss.read()

        model_comm = _tm.AbstractTransformerModel.create_llama_model(
            model_dir='',
            config=config,
            tensor_para_size=self.gpu_count,
            data_type=data_type)

        # create empty weight
        self._create_weight(model_comm)

        # copy hf model weight to turbomind weight
        tm_params = output_model.tm_params
        self._get_model_params(model_comm, tm_params)
        logger.warning(f'get {len(tm_params)} model params')
        output_model.export()

        # load kv qparams
        self._load_kv_qparams(model_path, tm_params, **kwargs)
        assert len(tm_params) == 0, f'missing {tm_params.keys()}'

        return model_comm

    def _from_workspace(self, model_path: str):
        """Load model which is converted by `lmdeploy convert`"""
        ini_path = osp.join(model_path, 'triton_models', 'weights',
                            'config.ini')
        with open(ini_path, 'r') as f:
            parser = ConfigParser()
            parser.read_file(f)
            section_name = 'llama'
            tp_cfg = parser.getint(section_name, 'tensor_para_size')

            if tp_cfg != 1 and tp_cfg != self.gpu_count:
                get_logger('turbomind').info(
                    f'found tp={tp_cfg} in config.ini.')
                self.gpu_count = tp_cfg
            self.model_name = parser.get(section_name, 'model_name')
            self.data_type = parser.get(section_name, 'weight_type')
            cfg = parser._sections[section_name]
            cfg = TurbomindModelConfig.from_dict(cfg)
            self.config = cfg

        # create model
        weight_dir = osp.join(model_path, 'triton_models', 'weights')
        model_comm = _tm.AbstractTransformerModel.create_llama_model(
            weight_dir,
            tensor_para_size=self.gpu_count,
            data_type=self.data_type)

        # create weight and load params
        self._create_weight(model_comm)
        return model_comm

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        model_name: Optional[str] = None,
                        model_format: Optional[str] = None,
                        group_size: Optional[int] = None,
                        tp: Optional[int] = None,
                        **kwargs):
        """LMDeploy's turbomind inference engine.

        Args:
            pretrained_model_name_or_path (str):
                It could be one of the following options:
                    - i) A local directory path of a turbomind model which is
                      converted by `lmdeploy convert` command or downloaded
                      from ii) and iii)
                    - ii) The model_id of a lmdeploy-quantized model hosted
                      inside a model repo on huggingface.co, such as
                      "InternLM/internlm-chat-20b-4bit",
                      "lmdeploy/llama2-chat-70b-4bit", etc.
                    - iii) The model_id of a model hosted inside a model repo
                      on huggingface.co, such as "InternLM/internlm-chat-7b",
                      "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                      and so on.
            model_name (str): needed when pretrained_model_name_or_path is iii)
            model_format (str): model format
            group_size (int): group size
            tp (int): tensor parallel size
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update configuration when initializing the
                engine.
        """
        model_source = get_model_source(pretrained_model_name_or_path)
        if model_source == ModelSource.WORKSPACE:
            local_path = pretrained_model_name_or_path
        else:
            check_tm_model_input(pretrained_model_name_or_path,
                                 model_name=model_name,
                                 **kwargs)
            if not osp.exists(pretrained_model_name_or_path):
                download_kwargs = create_hf_download_args(**kwargs)
                local_path = snapshot_download(pretrained_model_name_or_path,
                                               **download_kwargs)
            else:
                local_path = pretrained_model_name_or_path

        logger.warning(f'model_source: {model_source}')
        return cls(model_source=model_source,
                   model_path=local_path,
                   model_name=model_name,
                   model_format=model_format,
                   group_size=group_size,
                   tp=tp,
                   **kwargs)

    def create_instance(self, cuda_stream_id=0):
        """Create a turbomind instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream
        Returns:
            TurboMindInstance: an instance of turbomind
        """
        return TurboMindInstance(self, cuda_stream_id)


class TurboMindInstance:
    """Instance of TurboMind.

    Args:
        tm_model (TurboMind): the turbomind model
        cuda_stream_id(int): identity of a cuda stream
    """

    def __init__(self, tm_model: TurboMind, cuda_stream_id: int = 0):
        self.tm_model = tm_model
        self.cuda_stream_id = cuda_stream_id

        self.node_id = tm_model.node_id
        self.gpu_count = tm_model.gpu_count

        self.stop_words = tm_model.stop_words
        self.stop_tokens = [] if self.stop_words is None else \
            self.stop_words.flatten().tolist()
        self.eos_id = tm_model.eos_id
        self.session_len = tm_model.session_len

        self.nccl_params = tm_model.nccl_params

        # create model instances
        model_insts = [None] * self.gpu_count
        threads = []
        for device_id in range(self.gpu_count):
            t = Thread(target=self._create_model_instance,
                       args=(device_id, model_insts))
            t.start()
            threads.append(t)
        for t in threads:
            t.join()

        self.model_insts = model_insts
        self.que = Queue()
        self.threads = [None] * self.gpu_count

    def _create_model_instance(self, device_id, model_insts):
        with cuda_ctx(device_id):
            rank = self.node_id * self.gpu_count + device_id
            model_inst = self.tm_model.model_comm.create_model_instance(
                device_id, rank, self.cuda_stream_id, self.nccl_params)
            model_insts[device_id] = model_inst

    def _forward_callback(self, result, ctx):
        self.que.put((False, result))

    def _forward_thread(self, inputs):
        instance_comm = self.tm_model.model_comm.create_instance_comm(
            self.gpu_count)

        def _func(device_id, enque_output):
            with cuda_ctx(device_id):
                output = self.model_insts[device_id].forward(
                    inputs, instance_comm)
                if enque_output:
                    self.que.put((True, output))

        for device_id in range(self.gpu_count):
            t = Thread(target=_func,
                       args=(device_id, device_id == 0),
                       daemon=True)
            t.start()
            self.threads[device_id] = t

    async def async_stream_infer(self, *args, **kwargs):
        """Async wrapper of self.stream_infer."""
        for output in self.stream_infer(*args, **kwargs):
            # Allow the pipeline to add new requests into the queue.
            await asyncio.sleep(0)
            yield output

    def stream_infer(self,
                     session_id,
                     input_ids,
                     input_embeddings=None,
                     input_embedding_ranges=None,
                     request_output_len: int = 512,
                     sequence_start: bool = True,
                     sequence_end: bool = False,
                     step=0,
                     stop=False,
                     top_p=0.8,
                     top_k=40,
                     temperature=0.8,
                     repetition_penalty=1.0,
                     ignore_eos=False,
                     random_seed=None,
                     stream_output=False):
        """Perform model inference.

        Args:
            session_id (int): the id of a session
            input_ids (numpy.ndarray): the token ids of a prompt
            input_embeddings (List[numpy.ndarray]): embeddings features
            input_embedding_ranges (List[Tuple[int,int]]): the begin/end
              offsets of input_embeddings to input_ids
            request_output_len (int): the max number of to-be-generated tokens
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            step (int): the offset of the k/v cache
            stop (bool): indicator for cancelling the session
            top_p (float): If set to float < 1, only the smallest set of most
              probable tokens with probabilities that add up to top_p or higher
              are kept for generation.
            top_k (int): The number of the highest probability vocabulary
              tokens to keep for top-k-filtering
            temperature (float): to modulate the next token probability
            repetition_penalty (float): The parameter for repetition penalty.
              1.0 means no penalty
            ignore_eos (bool): indicator for ignoring eos
            random_seed (int): seed used by sampling
            stream_output (bool): indicator for stream output
        """
        if stream_output and not stop:
            self.model_insts[0].register_callback(self._forward_callback)

        if len(input_ids) == 0:
            input_ids = [[]]
        if isinstance(input_ids[0], int):
            input_ids = [input_ids]

        batch_size = len(input_ids)

        def _broadcast_np(data, dtype, shape=(batch_size, )):
            if isinstance(data, Iterable):
                assert len(data) == batch_size
                return data

            return np.full(shape, data, dtype=dtype)
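        # e.g. _broadcast_np(0.8, np.float32) yields `batch_size` copies of
        # 0.8, while a per-sample iterable is passed through unchanged.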

        input_ids = [torch.IntTensor(ids) for ids in input_ids]
        input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
        input_ids = pad_sequence(input_ids,
                                 batch_first=True,
                                 padding_value=self.eos_id)

        if isinstance(session_id, int):
            session_id = [session_id]
        assert len(session_id) == batch_size

        step = _broadcast_np(step, np.int32)

        inputs = dict(
            input_ids=input_ids,
            input_lengths=input_lengths,
            request_output_len=np.full(input_lengths.shape,
                                       request_output_len,
                                       dtype=np.uint32),
            runtime_top_k=_broadcast_np(top_k, np.uint32),
            runtime_top_p=_broadcast_np(top_p, np.float32),
            temperature=_broadcast_np(temperature, np.float32),
            repetition_penalty=_broadcast_np(repetition_penalty, np.float32),
            step=step,

            # session input
            session_len=self.session_len *
            np.ones([
                batch_size,
            ], dtype=np.uint32),
            START=_broadcast_np((1 if sequence_start else 0), np.int32),
            END=_broadcast_np((1 if sequence_end else 0), np.int32),
            CORRID=np.array(session_id, dtype=np.uint64),
            STOP=_broadcast_np((1 if stop else 0), np.int32))

        if input_embeddings is not None:
            assert len(input_embeddings) == len(input_embedding_ranges)
            if isinstance(input_embeddings[0], np.ndarray):
                input_embeddings = [input_embeddings]
                input_embedding_ranges = [input_embedding_ranges]
            # convert to lookup table type
            if self.tm_model.config.weight_type == 'fp32':
                input_embeddings = [[x.astype(np.float32) for x in y]
                                    for y in input_embeddings]
            elif self.tm_model.config.weight_type == 'bf16':
                input_embeddings = [[
                    torch.from_numpy(x).bfloat16().view(torch.half).numpy()
                    for x in y
                ] for y in input_embeddings]
            else:
                input_embeddings = [[x.astype(np.float16) for x in y]
                                    for y in input_embeddings]

            input_embeddings = [[torch.from_numpy(x).squeeze() for x in y]
                                for y in input_embeddings]
            input_embeddings = [torch.cat(x) for x in input_embeddings]
            input_embeddings = pad_sequence(input_embeddings, batch_first=True)
            input_embeddings = input_embeddings.reshape(
                input_embeddings.shape[0], -1).view(torch.int8)

            _input_embedding_ranges = []
            for x in input_embedding_ranges:
                if x is not None and len(x) != 0:
                    _input_embedding_ranges.append(torch.IntTensor(x))
                else:
                    _input_embedding_ranges.append(torch.IntTensor(size=(0,
                                                                         2)))
            input_embedding_ranges = pad_sequence(_input_embedding_ranges,
                                                  batch_first=True,
                                                  padding_value=-1)
            inputs['input_embeddings'] = input_embeddings
            inputs['input_embedding_ranges'] = input_embedding_ranges

        if ignore_eos:
            stop_words = None
            bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
        else:
            stop_words = self.stop_words
            bad_words = None

        if stop_words is not None:
            inputs['stop_words_list'] = stop_words
        if bad_words is not None:
            inputs['bad_words_list'] = bad_words

        if random_seed is not None:
            inputs['random_seed'] = _broadcast_np(random_seed, np.uint64)
        tm_inputs = _np_dict_to_tm_dict(inputs)

        # start forward thread
        self.que = Queue()
        self._forward_thread(tm_inputs)

        seq_start = input_lengths + input_lengths.new_tensor(step)

        # generator
        while True:
            while self.que.qsize() > 1:
                self.que.get()

            finish, tm_outputs = self.que.get()

            outputs = _tm_dict_to_torch_dict(tm_outputs)

            output_ids = outputs['output_ids'][:, 0, :]
            sequence_length = outputs['sequence_length'].long()[:, 0]
            output_ids = [
                output_id[s:l] for output_id, s, l in zip(
                    output_ids, seq_start, sequence_length)
            ]
            sequence_length -= seq_start.to(sequence_length.device)

            outputs = []
            for output, len_ in zip(output_ids, sequence_length):
                output, len_ = output, len_.item()
                if len(output) > 0 and output[-1].item(
                ) == self.eos_id and not ignore_eos:
                    outputs.append((output[:-1], len_ - 1))
                elif len(output) > 0 and output[-1].item() in self.stop_tokens:
                    outputs.append((output[:-1], len_))
                else:
                    outputs.append((output, len_))
            yield outputs

            if finish:
                for t in self.threads:
                    t.join()
                while self.que.qsize() > 0:
                    self.que.get()
                break

        if stream_output and not stop:
            self.model_insts[0].unregister_callback()

    def decode(self, input_ids):
        """Perform context decode on input tokens.

        Args:
            input_ids (numpy.ndarray): the batch of input token ids
        """

        if len(input_ids) == 0:
            input_ids = [[]]
        if isinstance(input_ids[0], int):
            input_ids = [input_ids]

        # append an extra token since input_len-1 tokens will be
        # decoded by context decoder
        for inputs in input_ids:
            inputs.append(0)

        batch_size = len(input_ids)

        def _broadcast_np(data, dtype, shape=(batch_size, )):
            if isinstance(data, Iterable):
                assert len(data) == batch_size
                return data

            return np.full(shape, data, dtype=dtype)

        input_ids = [torch.IntTensor(ids) for ids in input_ids]
        input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
        input_ids = pad_sequence(input_ids,
                                 batch_first=True,
                                 padding_value=self.eos_id)

        inputs = dict(input_ids=input_ids,
                      input_lengths=input_lengths,
                      request_output_len=_broadcast_np(0, dtype=np.uint32),
                      is_return_logits=_broadcast_np(1, np.uint32))

        tm_inputs = _np_dict_to_tm_dict(inputs)

        # start forward thread
        self._forward_thread(tm_inputs)

        _, tm_outputs = self.que.get()

        outputs = _tm_dict_to_torch_dict(tm_outputs)
        logits = outputs['logits']

        return logits[:, :-1, :]