vllm_utils.py 22.2 KB
Newer Older
lvzhen's avatar
first  
lvzhen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
import concurrent.futures
import inspect
import os
import time
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Optional, Tuple

import torch
import vllm
from modelscope import GenerationConfig
from packaging import version
from torch import dtype as Dtype
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from vllm import AsyncEngineArgs, AsyncLLMEngine, EngineArgs, LLMEngine, SamplingParams

from swift.utils import get_logger
from .argument import InferArguments
from .model import get_model_tokenizer
from .template import Template, get_template

try:
    from vllm.lora.request import LoRARequest
except ImportError:
    pass

# Module-level logger obtained from swift's `get_logger` utility; used for warnings/info below.
logger = get_logger()


def get_vllm_engine(
        model_type: str,
        torch_dtype: Optional[Dtype] = None,
        *,
        model_id_or_path: Optional[str] = None,
        revision: Optional[str] = None,
        gpu_memory_utilization: float = 0.9,
        tensor_parallel_size: int = 1,
        max_num_seqs: int = 256,
        max_model_len: Optional[int] = None,
        disable_custom_all_reduce: bool = True,  # Default values different from vllm
        enforce_eager: bool = False,
        engine_kwargs: Optional[Dict[str, Any]] = None,
        use_async: bool = False,
        # lora
        enable_lora: bool = False,
        max_loras: int = 1,
        max_lora_rank: int = 16,
        **kwargs) -> LLMEngine:
    """Create a vLLM engine for `model_type` and attach swift-specific attributes to it.

    The tokenizer is obtained through `get_model_tokenizer(load_model=False, download_model=True)`
    and its `model_dir` is handed to vLLM; the swift tokenizer then replaces the one vLLM built.

    Args:
        model_type: swift model type used to resolve the model directory and tokenizer.
        torch_dtype: torch.float16 / torch.bfloat16 / torch.float32, or None for vLLM's 'auto'.
        model_id_or_path, revision: forwarded to `get_model_tokenizer`.
        gpu_memory_utilization, tensor_parallel_size, max_num_seqs, max_model_len,
        disable_custom_all_reduce, enforce_eager: forwarded to vLLM's (Async)EngineArgs.
        engine_kwargs: extra keyword arguments for (Async)EngineArgs. `disable_log_stats`
            defaults to True. The caller's dict is not modified.
        use_async: build an `AsyncLLMEngine` instead of an `LLMEngine`.
        enable_lora, max_loras, max_lora_rank: LoRA settings, only passed to vLLM when
            `enable_lora` is True.
        **kwargs: only `model_dir` (swift<1.7 compatibility) is read.

    Returns:
        The engine, with extra attributes set: engine_args, model_dir, model_type,
        model_config, dtype, max_model_len, is_multimodal, hf_tokenizer, generation_config.

    Raises:
        KeyError: if `torch_dtype` is not one of the mapped dtypes.
        AssertionError: if `enable_lora` is requested but the installed vLLM's EngineArgs
            has no `enable_lora` parameter.
    """
    model_dir = kwargs.pop('model_dir', None)  # compat with swift<1.7
    tokenizer = get_model_tokenizer(
        model_type,
        load_model=False,
        model_id_or_path=model_id_or_path,
        model_dir=model_dir,
        revision=revision,
        download_model=True)[1]
    model_dir = tokenizer.model_dir

    # Work on a copy: the pop/assignments below must not leak into the caller's dict
    # (otherwise reusing the same dict for a second engine behaves differently).
    engine_kwargs = {} if engine_kwargs is None else dict(engine_kwargs)
    dtype_mapping = {torch.float16: 'float16', torch.bfloat16: 'bfloat16', torch.float32: 'float32', None: 'auto'}
    dtype = dtype_mapping[torch_dtype]
    disable_log_stats = engine_kwargs.pop('disable_log_stats', True)

    if use_async:
        engine_args_cls = AsyncEngineArgs
        llm_engine_cls = AsyncLLMEngine
        engine_kwargs['disable_log_requests'] = True
    else:
        engine_args_cls = EngineArgs
        llm_engine_cls = LLMEngine

    # LoRA support is detected from the installed EngineArgs signature.
    parameters = inspect.signature(engine_args_cls.__init__).parameters
    if 'enable_lora' in parameters and enable_lora:
        engine_kwargs['enable_lora'] = enable_lora
        engine_kwargs['max_loras'] = max_loras
        engine_kwargs['max_lora_rank'] = max_lora_rank
    else:
        assert not enable_lora, 'The current version of VLLM does not support `enable_lora`. Please upgrade VLLM.'

    engine_args = engine_args_cls(
        model=model_dir,
        trust_remote_code=True,
        dtype=dtype,
        gpu_memory_utilization=gpu_memory_utilization,
        tensor_parallel_size=tensor_parallel_size,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        disable_log_stats=disable_log_stats,
        disable_custom_all_reduce=disable_custom_all_reduce,
        enforce_eager=enforce_eager,
        **engine_kwargs)
    try:
        # Only exists in some vLLM versions; best-effort reset of model-parallel state.
        from vllm.model_executor.parallel_utils.parallel_state import destroy_model_parallel
        destroy_model_parallel()
    except ImportError:
        pass
    # fix HTTPError bug (use model_dir)
    os.environ.pop('VLLM_USE_MODELSCOPE', None)
    if version.parse(vllm.__version__) >= version.parse('0.5.1'):
        os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

    llm_engine = llm_engine_cls.from_engine_args(engine_args)
    llm_engine.engine_args = engine_args
    llm_engine.model_dir = model_dir
    llm_engine.model_type = model_type

    # The async engine exposes the underlying engine as `.engine`.
    if use_async:
        _engine = llm_engine.engine
    else:
        _engine = llm_engine
    model_config = _engine.model_config
    llm_engine.model_config = model_config
    llm_engine.dtype = model_config.dtype  # compat with pt
    llm_engine.max_model_len = model_config.max_model_len
    llm_engine.is_multimodal = tokenizer.is_multimodal
    # compatible with vllm==0.3.*: the HF tokenizer sits one level deeper from 0.3 on.
    if version.parse(vllm.__version__) >= version.parse('0.3'):
        assert isinstance(_engine.tokenizer.tokenizer, PreTrainedTokenizerBase)
        _engine.tokenizer.tokenizer = tokenizer

        # fix vllm==0.4 bug (very slow): cache len() for this tokenizer instance only;
        # other instances of the same class keep the original __len__.
        if version.parse(vllm.__version__) >= version.parse('0.4'):
            _tokenizer_len = len(tokenizer)
            __old_len__ = tokenizer.__class__.__len__

            def __len__(self) -> int:
                if self is tokenizer:
                    return _tokenizer_len
                else:
                    return __old_len__(self)

            tokenizer.__class__.__len__ = __len__

    else:
        assert isinstance(_engine.tokenizer, PreTrainedTokenizerBase)
        _engine.tokenizer = tokenizer

    llm_engine.hf_tokenizer = tokenizer
    # Default sampling settings come from the model's generation_config.json when present;
    # keys not accepted explicitly by VllmGenerationConfig.__init__ are dropped.
    generation_config_path = os.path.join(model_dir, 'generation_config.json')
    if os.path.isfile(generation_config_path):
        generation_config = GenerationConfig.from_pretrained(model_dir)
        config_kwargs = generation_config.to_dict()
        parameters = inspect.signature(VllmGenerationConfig.__init__).parameters
        for k in list(config_kwargs.keys()):
            if k not in parameters:
                config_kwargs.pop(k)
        llm_engine.generation_config = VllmGenerationConfig(**config_kwargs)
    else:
        llm_engine.generation_config = VllmGenerationConfig()
    return llm_engine


class VllmGenerationConfig(SamplingParams):
    """vLLM `SamplingParams` with a constructor modeled on `transformers.GenerationConfig`.

    `max_new_tokens` maps to `max_tokens`, `num_beams > 1` maps to beam search, and any keyword
    the installed `SamplingParams.__init__` does not accept is dropped (logged at info level).
    """

    def __init__(
        self,
        max_new_tokens: Optional[int] = 64,  # max_tokens
        temperature: float = 1.,
        top_k: int = 50,  # -1: all
        top_p: float = 1.,
        repetition_penalty: float = 1.,
        num_beams: int = 1,
        *,
        n: int = 1,
        seed: Optional[int] = None,
        length_penalty: float = 1.,
        stop: Optional[List[str]] = None,
        skip_special_tokens: bool = False,
        **kwargs,
    ) -> None:
        """Build the sampling parameters.

        Args:
            max_new_tokens: maximum number of generated tokens; None falls back to 64.
            temperature, top_k, top_p, repetition_penalty: sampling settings. top_k=0 is
                translated to -1.
            num_beams: values > 1 enable beam search with best_of=num_beams and force
                top_k=-1, top_p=1, temperature=0; `use_beam_search`/`best_of` must then not
                also be passed via kwargs.
            n, seed, length_penalty, stop, skip_special_tokens: forwarded to SamplingParams;
                stop=None becomes a new empty list.
            **kwargs: any other SamplingParams argument.
        """
        # The parameter design is similar to transformers.GenerationConfig.
        if max_new_tokens is None:
            max_new_tokens = 64
        if num_beams > 1:
            top_k = -1
            top_p = 1
            temperature = 0
            logger.warning(
                'The output of num_beams in vllm may not be consistent with the output of num_beams in transformers.')
        if top_k == 0:
            top_k = -1
        if stop is None:
            stop = []
        kwargs['max_tokens'] = max_new_tokens
        kwargs['temperature'] = temperature
        kwargs['top_k'] = top_k
        kwargs['top_p'] = top_p
        kwargs['repetition_penalty'] = repetition_penalty
        if num_beams > 1:
            best_of = kwargs.get('best_of')
            assert 'use_beam_search' not in kwargs and best_of is None
            kwargs['use_beam_search'] = True
            kwargs['best_of'] = num_beams
        kwargs['n'] = n
        kwargs['seed'] = seed
        kwargs['length_penalty'] = length_penalty
        kwargs['stop'] = stop
        kwargs['skip_special_tokens'] = skip_special_tokens
        # Drop arguments unknown to the installed vLLM version instead of failing.
        parameters = inspect.signature(SamplingParams.__init__).parameters
        for k in kwargs.copy().keys():
            if k not in parameters:
                logger.info(f'The VLLM version is too old and does not support the parameter: {k}.')
                kwargs.pop(k)
        # Remembered so that `do_sample = True` can restore it later.
        self._temperature = temperature
        super().__init__(**kwargs)

    def __setattr__(self, key: str, value: Any) -> None:
        """Translate transformers-style attribute names onto SamplingParams fields.

        - `max_new_tokens` sets `max_tokens`.
        - `do_sample` (a bool) sets `temperature` to the temperature recorded at construction
          when True, or to 0 when False.
        - `max_length` raises ValueError (not supported).
        Any other attribute is set normally.
        """
        if key == 'max_new_tokens':
            self.max_tokens = value
        elif key == 'do_sample':
            assert value in {True, False}
            if value:
                self.temperature = self._temperature
            else:
                self.temperature = 0.
        elif key == 'max_length':
            raise ValueError('`max_length` is not supported, please use `max_new_tokens` for setting.')
        else:
            super().__setattr__(key, value)


def _add_vllm_request(llm_engine: LLMEngine, inputs: Dict[str, Any], *, request_id: str,
                      generation_config: VllmGenerationConfig, **kwargs) -> None:
    """Enqueue one encoded request on `llm_engine`, using the call form of the installed vLLM.

    vLLM >= 0.4.3 takes a single inputs dict (`prompt_token_ids`, plus `multi_modal_data` when
    an image is attached); older versions take `prompt=None` and the token ids positionally.

    Args:
        llm_engine: engine to add the request to.
        inputs: encoded template output; 'input_ids' is required, 'images' is optional and
            may hold at most one image.
        request_id: identifier under which the engine reports this request's outputs.
        generation_config: sampling parameters for the request.
        **kwargs: forwarded to `llm_engine.add_request` (e.g. `lora_request`).
    """
    token_ids = inputs['input_ids']
    if version.parse(vllm.__version__) < version.parse('0.4.3'):
        # Legacy signature: (request_id, prompt, sampling_params, prompt_token_ids, ...).
        llm_engine.add_request(request_id, None, generation_config, token_ids, **kwargs)
        return

    engine_inputs = {'prompt_token_ids': token_ids}
    image_list = inputs.get('images') or []
    if image_list:
        assert len(image_list) == 1, 'Currently, only one image is supported.'
        engine_inputs['multi_modal_data'] = {'image': image_list[0]}
    llm_engine.add_request(request_id, engine_inputs, generation_config, **kwargs)


def _prepare_vllm_request(
        llm_engine: LLMEngine,
        template: Template,
        request_list: List[Dict[str, Any]],
        *,
        generation_config: VllmGenerationConfig,
        generation_info: Dict[str, Any],
        lora_request: Optional['LoRARequest'] = None,
        use_tqdm: bool = False,
        **kwargs) -> Tuple[List[Optional[Dict[str, Any]]], List[Tuple[bool, Optional[int]]]]:
    """Encode every request with `template` and add the encodable ones to `llm_engine`.

    Side effects:
        - Appends the tokenizer's eos token and the template's last suffix (decoded if it is a
          token-id list) to `generation_config.stop`; callers pass a deep copy.
        - Sets `template.model = llm_engine`.
        - Rewrites each request in place: 'history' becomes a list, and for an agent turn (last
          response ending in 'Observation:') the query is appended to that response and
          'query' is set to None.
        - Initializes/increments generation_info['num_prompt_tokens'] and ['num_samples'];
          ensures ['num_generated_tokens'] exists.
        - Request i is submitted with request_id str(i).

    Encoding uses a thread pool (with an optional progress bar) only for multimodal engines.

    Args:
        kwargs: only 'truncation_strategy' is read (default 'delete'); with 'delete', a
            request whose encoding is empty is answered with an empty response instead of
            being submitted.

    Returns:
        resp_list: one slot per request, pre-filled for dropped requests, otherwise None.
        agent_state: per request, (is_observation, length of the last response after the
            query was appended, or None).
    """
    for key in ['num_prompt_tokens', 'num_generated_tokens', 'num_samples']:
        if key not in generation_info:
            generation_info[key] = 0

    template.model = llm_engine
    tokenizer = template.tokenizer
    if tokenizer.eos_token is not None and tokenizer.eos_token not in generation_config.stop:
        generation_config.stop.append(tokenizer.eos_token)
    if isinstance(template.suffix[-1], str) and template.suffix[-1] not in generation_config.stop:
        generation_config.stop.append(template.suffix[-1])
    if isinstance(template.suffix[-1], list):
        token_str = tokenizer.decode(template.suffix[-1])
        if token_str not in generation_config.stop:
            generation_config.stop.append(token_str)

    # Older vLLM versions have no `lora_request` argument on add_request.
    parameters = inspect.signature(llm_engine.add_request).parameters
    add_request_kwargs = {}
    if 'lora_request' in parameters:
        add_request_kwargs['lora_request'] = lora_request
    else:
        assert lora_request is None, (
            'The current version of VLLM does not support `lora_request`. Please upgrade VLLM.')

    # Read once, before the loop. Popping it per request honored the caller's value only for
    # the first request and silently fell back to 'delete' for all the others.
    truncation_strategy = kwargs.pop('truncation_strategy', 'delete')

    resp_list: List[Optional[Dict[str, Any]]] = [None] * len(request_list)
    is_multimodal = getattr(llm_engine, 'is_multimodal', False)
    max_workers = os.cpu_count() or 1  # os.cpu_count() may return None
    if not is_multimodal:
        use_tqdm = False
        max_workers = 1

    prog_bar = tqdm(request_list, dynamic_ncols=True, disable=not use_tqdm)

    def _prepare_inputs(request: Dict[str, Any]) -> Tuple[Dict[str, Any], Tuple[bool, Optional[int]]]:
        """Apply agent handling to `request` in place, encode it, and return (inputs, agent state)."""
        history = list(request.get('history') or [])
        # agent support
        is_observation = history[-1][-1].endswith('Observation:') if history and history[-1][-1] else False
        act_length = None
        if is_observation:
            history[-1] = list(history[-1])  # turns may be tuples; this one is written to below
            history[-1][-1] = history[-1][-1] + request['query']
            act_length = len(history[-1][-1])
            request['query'] = None
        request['history'] = history

        inputs = template.encode(request)[0]
        prog_bar.update()
        return inputs, (is_observation, act_length)

    with template.vllm_context(), concurrent.futures.ThreadPoolExecutor(
            max_workers=max(1, min(max_workers, len(request_list)))) as executor:
        futures = [executor.submit(_prepare_inputs, request) for request in request_list]
        # Results are collected in submission order, so entry i belongs to request_list[i].
        prepared = [future.result() for future in futures]
    prog_bar.close()
    # Built from the ordered results. Appending from inside the workers ordered agent_state by
    # completion time, misaligning it with request indices whenever max_workers > 1.
    agent_state = [state for _, state in prepared]

    for i, ((inputs, _), request) in enumerate(zip(prepared, request_list)):
        if len(inputs) == 0 and truncation_strategy == 'delete':
            # input_ids exceeds `max_length`. Please increase the value of `max_length`.
            resp_list[i] = {'response': '', 'history': request['history']}
            continue
        generation_info['num_prompt_tokens'] += len(inputs['input_ids'])
        generation_info['num_samples'] += 1
        _add_vllm_request(
            llm_engine, inputs, request_id=str(i), generation_config=generation_config, **add_request_kwargs)
    return resp_list, agent_state


@torch.inference_mode()
def inference_stream_vllm(
        llm_engine: LLMEngine,
        template: Template,
        request_list: List[Dict[str, Any]],
        *,
        generation_config: Optional[VllmGenerationConfig] = None,
        generation_info: Optional[Dict[str, Any]] = None,
        lora_request: Optional['LoRARequest'] = None,
        use_tqdm: bool = False,
        flush_steps: Optional[int] = None,  # Ensuring efficiency
        **kwargs) -> Iterator[List[Dict[str, Any]]]:
    """
    request_list: e.g. [{'query': 'hello!'}].
        The keys that can be included are: 'query', 'history', 'system', 'images'.
    generation_config: Priority: generation_config > model.generation_config.
    generation_info: if given, it is cleared and then filled with 'num_prompt_tokens',
        'num_generated_tokens', 'num_samples', 'runtime', 'samples/s' and 'tokens/s'.
    flush_steps: partial results are yielded every `flush_steps` engine steps, and whenever a
        request finishes; defaults to min(10, number of submitted requests). Values < 1 are
        treated as 1.
    return: e.g. [{'response': 'hi!', 'history': [('hello!', 'hi!')]}].
        The keys to be included will be: 'response', 'history'.
        The same list object is yielded repeatedly and updated in place.
    raises: ValueError if `generation_config` enables beam search.
    """
    if len(request_list) == 0:
        return
    start_runtime = time.perf_counter()
    if generation_config is None:
        generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
    assert isinstance(generation_config, VllmGenerationConfig)
    # Validate before anything is submitted. Previously this ran after the requests were
    # added, so raising left them queued in the engine for the next call to pick up.
    if getattr(generation_config, 'use_beam_search', False):
        error_msg = 'Streaming generation does not support beam search.'
        raise ValueError(error_msg)
    request_list = deepcopy(request_list)
    generation_config = deepcopy(generation_config)
    if generation_info is None:
        generation_info = {}
    else:
        generation_info.clear()

    resp_list, agent_state = _prepare_vllm_request(
        llm_engine,
        template,
        request_list,
        generation_config=generation_config,
        generation_info=generation_info,
        lora_request=lora_request,
        use_tqdm=use_tqdm,
        **kwargs)

    if generation_info['num_samples'] == 0:
        # Every request was dropped during encoding: the engine has nothing to run, but the
        # caller still receives the pre-filled (empty) responses once.
        runtime = time.perf_counter() - start_runtime
        generation_info['runtime'] = runtime
        generation_info['samples/s'] = 0 / runtime
        generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
        yield resp_list
        return

    n_finished = 0
    n_steps = 0
    if flush_steps is None:
        flush_steps = min(10, generation_info['num_samples'])
    flush_steps = max(flush_steps, 1)  # avoid `n_steps % 0`
    print_idx_list = [[0] for _ in range(len(request_list))]
    num_generated_tokens = [0] * len(request_list)
    prog_bar = tqdm(total=generation_info['num_samples'], dynamic_ncols=True, disable=not use_tqdm)
    try:
        while llm_engine.has_unfinished_requests():
            is_flush = False
            n_steps += 1
            step_outputs = llm_engine.step()
            for output in step_outputs:
                if not output.finished and n_steps % flush_steps != 0:
                    continue
                is_flush = True
                i = int(output.request_id)
                request = request_list[i]
                generate_ids = output.outputs[0].token_ids
                safe_response = template.generate_ids_to_response(
                    generate_ids, output.finished, print_idx=print_idx_list[i])
                query = request['query']
                history = request['history']
                # Non-agent requests get a new turn on first flush; agent requests extend the
                # last response, keeping only the prefix recorded in agent_state.
                if resp_list[i] is None and not agent_state[i][0]:
                    history.append(None)
                if not agent_state[i][0]:
                    history[-1] = [query, safe_response]
                else:
                    history[-1][-1] = history[-1][-1][:agent_state[i][1]] + safe_response

                # Count only tokens produced since this request's previous flush.
                n_gen_tokens = sum(len(_output.token_ids) for _output in output.outputs)
                generation_info['num_generated_tokens'] += n_gen_tokens - num_generated_tokens[i]
                num_generated_tokens[i] = n_gen_tokens

                resp_list[i] = {'response': safe_response, 'history': history}
                if output.finished:
                    n_finished += 1
                    prog_bar.update()
            if not is_flush:
                continue
            runtime = time.perf_counter() - start_runtime
            generation_info['runtime'] = runtime
            generation_info['samples/s'] = n_finished / runtime
            generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
            yield resp_list
    finally:
        # Also runs when the consumer stops iterating early.
        prog_bar.close()


@torch.inference_mode()
def inference_vllm(llm_engine: LLMEngine,
                   template: Template,
                   request_list: List[Dict[str, Any]],
                   *,
                   generation_config: Optional[VllmGenerationConfig] = None,
                   generation_info: Optional[Dict[str, Any]] = None,
                   max_batch_size: Optional[int] = None,
                   lora_request: Optional['LoRARequest'] = None,
                   use_tqdm: bool = False,
                   verbose: bool = False,
                   prompt_prefix: str = '[PROMPT]',
                   output_prefix: str = '[OUTPUT]',
                   **kwargs) -> List[Dict[str, Any]]:
    """
    request_list: e.g. [{'query': 'hello!'}].
        The keys that can be included are: 'query', 'history', 'system', 'images'.
    generation_config: Priority: generation_config > model.generation_config.
    generation_info: if given, it is cleared (except in internal chunked calls) and filled
        with 'num_prompt_tokens', 'num_generated_tokens', 'num_samples', 'runtime',
        'samples/s' and 'tokens/s'.
    max_batch_size: when set and smaller than len(request_list), requests are processed in
        consecutive chunks of this size; defaults to 512 for multimodal engines. Must be > 0.
    use_tqdm / verbose: cannot both be enabled; verbose prints the decoded prompt and output
        of every request, prefixed by `prompt_prefix` / `output_prefix`.
    return: e.g. [{'response': 'hi!', 'history': [('hello!', 'hi!')]}].
        The keys to be included will be: 'response', 'history'.
    raises: ValueError if max_batch_size is not None and <= 0.
    """
    if len(request_list) == 0:
        return []
    if max_batch_size is not None and max_batch_size <= 0:
        # A non-positive chunk size would never advance the chunking loop below.
        raise ValueError(f'max_batch_size must be a positive integer, got {max_batch_size}.')
    start_runtime = time.perf_counter()

    is_multimodal = getattr(llm_engine, 'is_multimodal', False)
    if is_multimodal and max_batch_size is None:
        max_batch_size = 512

    # Chunked (inner) calls accumulate into the caller's generation_info instead of clearing it.
    _inner_call = kwargs.get('_inner_call', False)
    if generation_info is None:
        generation_info = {}
    elif not _inner_call:
        generation_info.clear()
    if max_batch_size is not None and len(request_list) > max_batch_size:
        i = 0
        resp_list = []
        kwargs['_inner_call'] = True
        while i < len(request_list):
            resp_list += inference_vllm(
                llm_engine,
                template,
                request_list[i:i + max_batch_size],
                generation_config=generation_config,
                generation_info=generation_info,
                max_batch_size=max_batch_size,
                lora_request=lora_request,
                use_tqdm=use_tqdm,
                verbose=verbose,
                prompt_prefix=prompt_prefix,
                output_prefix=output_prefix,
                **kwargs)
            i += max_batch_size
        runtime = time.perf_counter() - start_runtime
        generation_info['runtime'] = runtime
        generation_info['samples/s'] = generation_info['num_samples'] / runtime
        generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
        return resp_list

    if generation_config is None:
        generation_config = getattr(llm_engine, 'generation_config', VllmGenerationConfig())
    assert isinstance(generation_config, VllmGenerationConfig)
    request_list = deepcopy(request_list)
    generation_config = deepcopy(generation_config)

    # num_samples may already include earlier chunks; the progress bar counts this chunk only.
    old_num_samples = generation_info.get('num_samples', 0)
    resp_list, agent_state = _prepare_vllm_request(
        llm_engine,
        template,
        request_list,
        generation_config=generation_config,
        generation_info=generation_info,
        lora_request=lora_request,
        use_tqdm=use_tqdm,
        **kwargs)

    tokenizer = template.tokenizer
    if use_tqdm:
        assert verbose is False
    prog_bar = tqdm(total=generation_info['num_samples'] - old_num_samples, dynamic_ncols=True, disable=not use_tqdm)
    outputs = []
    while llm_engine.has_unfinished_requests():
        step_outputs = llm_engine.step()
        for output in step_outputs:
            if output.finished:
                outputs.append(output)
                prog_bar.update()
    prog_bar.close()

    for output in outputs:
        i = int(output.request_id)
        request = request_list[i]
        generate_ids = output.outputs[0].token_ids
        response = template.generate_ids_to_response(generate_ids)
        query = request['query']
        history = request['history']
        # Agent requests continue the last response instead of opening a new turn.
        if not agent_state[i][0]:
            history.append([query, response])
        else:
            history[-1][-1] = history[-1][-1] + response

        generation_info['num_generated_tokens'] += sum(len(_output.token_ids) for _output in output.outputs)
        resp_list[i] = {'response': response, 'history': history}
        if verbose:
            print(f'{prompt_prefix}{tokenizer.decode(output.prompt_token_ids, False)}{output_prefix}', end='')
            print(tokenizer.decode(output.outputs[0].token_ids, False))
    runtime = time.perf_counter() - start_runtime
    generation_info['runtime'] = runtime
    generation_info['samples/s'] = generation_info['num_samples'] / runtime
    generation_info['tokens/s'] = generation_info['num_generated_tokens'] / runtime
    return resp_list


def prepare_vllm_engine_template(args: InferArguments, use_async: bool = False) -> Tuple[LLMEngine, Template]:
    """Build the vLLM engine and the chat template described by `args`.

    Side effects on `args`: `temperature` is set to 0 when `do_sample` is False, and `system`
    is replaced by the template's `default_system`.

    Args:
        args: inference arguments (model, vLLM, sampling and template settings).
        use_async: build an `AsyncLLMEngine` instead of an `LLMEngine`.

    Returns:
        (llm_engine, template), with `llm_engine.generation_config` built from `args`.

    Raises:
        AssertionError: for a LoRA checkpoint when `args.vllm_enable_lora` is False
            (the LoRA must be merged first).
    """
    logger.info(f'device_count: {torch.cuda.device_count()}')

    assert not (args.sft_type == 'lora' and not args.vllm_enable_lora), 'you need to merge lora'
    # Loading Model and Tokenizer
    model_id_or_path = None
    if args.sft_type == 'full' and args.ckpt_dir is not None:
        model_id_or_path = args.ckpt_dir
    elif args.model_id_or_path is not None:
        model_id_or_path = args.model_id_or_path
    llm_engine = get_vllm_engine(
        args.model_type,
        args.torch_dtype,
        gpu_memory_utilization=args.gpu_memory_utilization,
        tensor_parallel_size=args.tensor_parallel_size,
        max_num_seqs=args.max_num_seqs,
        max_model_len=args.max_model_len,
        disable_custom_all_reduce=args.disable_custom_all_reduce,
        enforce_eager=args.enforce_eager,
        use_async=use_async,
        model_id_or_path=model_id_or_path,
        enable_lora=args.vllm_enable_lora,
        # One LoRA slot per module, and at least one. The previous `min(len(...), 1)` gave
        # 0 slots without modules and a single slot for several modules.
        max_loras=max(len(args.lora_modules), 1),
        max_lora_rank=args.vllm_max_lora_rank)
    tokenizer = llm_engine.hf_tokenizer

    if not args.do_sample:
        args.temperature = 0
    generation_config = VllmGenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        stop=args.stop_words,
        repetition_penalty=args.repetition_penalty,
        num_beams=args.num_beams)
    logger.info(f'generation_config: {generation_config}')
    llm_engine.generation_config = generation_config
    template: Template = get_template(
        args.template_type,
        tokenizer,
        args.system,
        args.max_length,
        args.truncation_strategy,
        model=llm_engine,
        tools_prompt=args.tools_prompt)
    args.system = template.default_system
    logger.info(f'system: {args.system}')
    return llm_engine, template