infer.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import datetime as dt
import os
import shutil
from typing import Any, Dict, Literal, Optional, Tuple

import json
import numpy as np
import torch
from modelscope import BitsAndBytesConfig, GenerationConfig
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import is_torch_npu_available

from swift.tuners import Swift
from swift.utils import (append_to_jsonl, get_logger, get_main, get_model_info, read_multi_line, seed_everything,
                         show_layers)
from .utils import (DeployArguments, InferArguments, Template, get_additional_saved_files, get_dataset,
                    get_model_tokenizer, get_template, inference, inference_stream, is_adapter, sample_dataset,
                    set_generation_config)

logger = get_logger()


def save_checkpoint(model: Optional[PreTrainedModel],
                    tokenizer: PreTrainedTokenizerBase,
                    model_cache_dir: str,
                    ckpt_dir: Optional[str],
                    target_dir: str,
                    *,
                    save_safetensors: bool = True,
                    **kwargs) -> None:
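    """Save ``model`` and ``tokenizer`` to ``target_dir`` together with auxiliary files.

    Copies ``generation_config.json``, ``preprocessor_config.json`` and any
    model-type-specific files from ``ckpt_dir`` or ``model_cache_dir``, rewrites
    ``configuration.json`` without its ``adapter_cfg`` entry, and, when ``ckpt_dir``
    is given, writes an ``sft_args.json`` with ``sft_type`` set to ``'full'``.
    """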
    if model is not None:
        model.save_pretrained(target_dir, safe_serialization=save_safetensors)
    if hasattr(tokenizer, 'processor'):
        tokenizer.processor.save_pretrained(target_dir)
    tokenizer.save_pretrained(target_dir)
    model_type = getattr(tokenizer, 'model_type', None)
    fname_list = ['generation_config.json', 'preprocessor_config.json']
    if model_type is not None:
        fname_list += get_additional_saved_files(model_type)

    for fname in fname_list:
        tgt_path = os.path.join(target_dir, fname)
        for model_dir in [ckpt_dir, model_cache_dir]:
            if model_dir is None:
                continue
            src_path = os.path.join(model_dir, fname)
            if os.path.isfile(src_path):
                shutil.copy(src_path, tgt_path)
                break
            elif os.path.isdir(src_path):
                shutil.copytree(src_path, tgt_path)
                break
    # configuration.json
    configuration_fname = 'configuration.json'
    new_configuration_path = os.path.join(target_dir, configuration_fname)
    for model_dir in [ckpt_dir, model_cache_dir]:
        if model_dir is None:
            continue
        old_configuration_path = os.path.join(model_dir, configuration_fname)
        if os.path.exists(old_configuration_path):
            with open(old_configuration_path, 'r', encoding='utf-8') as f:
                res = json.load(f)
            res.pop('adapter_cfg', None)
            with open(new_configuration_path, 'w', encoding='utf-8') as f:
                json.dump(res, f, ensure_ascii=False, indent=4)
            break
    if ckpt_dir is not None:
        # sft_args.json
        sft_args_fname = 'sft_args.json'
        old_sft_args_path = os.path.join(ckpt_dir, sft_args_fname)
        new_sft_args_path = os.path.join(target_dir, sft_args_fname)
        if os.path.exists(old_sft_args_path):
            with open(old_sft_args_path, 'r', encoding='utf-8') as f:
                res = json.load(f)
            res['sft_type'] = 'full'
            dtype = kwargs.get('dtype')
            if dtype is not None:
                res['dtype'] = dtype
            with open(new_sft_args_path, 'w', encoding='utf-8') as f:
                json.dump(res, f, ensure_ascii=False, indent=2)


def merge_lora(args: InferArguments,
               replace_if_exists=False,
               device_map: Optional[str] = None,
               **kwargs) -> Optional[str]:
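    """Merge LoRA weights into the base model and save a full checkpoint.

    The merged weights are written to ``<ckpt_dir>-merged`` next to ``args.ckpt_dir``,
    unless that directory already exists and ``replace_if_exists`` is False.
    ``args.sft_type`` and ``args.ckpt_dir`` are updated in place to point at the
    merged checkpoint, and the merged path is returned.
    """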
    logger.info(f'replace_if_exists: {replace_if_exists}')
    assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
    assert args.sft_type in ('lora', 'adalora', 'longlora'), 'Only LoRA-series tuners are supported.'
    for s in ['int4', 'int8', 'awq']:
        assert s not in args.model_type, f'Merging {s}-quantized models is not supported.'
    if args.quantization_bit != 0:
        logger.warning('It is not recommended to merge quantized models, '
                       'as this can result in performance degradation')
    ckpt_dir, ckpt_name = os.path.split(args.ckpt_dir)
    merged_lora_path = os.path.join(ckpt_dir, f'{ckpt_name}-merged')
    logger.info(f'merged_lora_path: `{merged_lora_path}`')
    if os.path.exists(merged_lora_path) and not replace_if_exists:
        logger.info(f'The merged LoRA weights already exist at {merged_lora_path}, '
                    'skipping the saving process. '
                    'You can pass `replace_if_exists=True` to overwrite them.')
    else:
        model, template = prepare_model_template(args, device_map=args.merge_device_map, verbose=False)
        logger.info('Merge LoRA...')
        Swift.merge_and_unload(model)
        model = model.model
        logger.info('Saving merged weights...')
        save_checkpoint(
            model,
            template.tokenizer,
            model.model_dir,
            args.ckpt_dir,
            merged_lora_path,
            save_safetensors=args.save_safetensors,
            dtype=args.dtype)
        logger.info(f'Successfully merged LoRA and saved in {merged_lora_path}.')
    logger.info("Setting args.sft_type: 'full'")
    logger.info(f'Setting args.ckpt_dir: {merged_lora_path}')
    args.sft_type = 'full'
    args.ckpt_dir = merged_lora_path
    return merged_lora_path


def prepare_model_template(args: InferArguments,
                           *,
                           device_map: Optional[str] = None,
                           verbose: bool = True,
                           automodel_class=None) -> Tuple[PreTrainedModel, Template]:
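    """Load the model and tokenizer and build the chat ``Template``.

    Handles device mapping (CUDA/NPU), optional quantization (a pre-built
    ``args.quant_config`` or bitsandbytes flags), the generation config, and
    attaching LoRA adapters when ``args.sft_type`` is an adapter type and
    ``args.ckpt_dir`` is set.
    """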

    model_kwargs = {}
    if is_torch_npu_available():
        logger.info(f'device_count: {torch.npu.device_count()}')
        if device_map is None:
            device_map = 'npu:0'
    else:
        logger.info(f'device_count: {torch.cuda.device_count()}')
        if device_map is None:
            device_map = 'auto'
    if device_map == 'auto':
        model_kwargs['low_cpu_mem_usage'] = True
    model_kwargs['device_map'] = device_map

    # Loading Model and Tokenizer
    if hasattr(args, 'quant_config'):
        model_kwargs['quantization_config'] = args.quant_config
    elif args.load_in_8bit or args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=args.load_in_8bit,
            load_in_4bit=args.load_in_4bit,
            bnb_4bit_compute_dtype=args.bnb_4bit_compute_dtype,
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=args.bnb_4bit_use_double_quant)
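        # BitsAndBytesConfig silently replaces a None compute dtype with float32;
        # reset it to None here so the unset value is preserved.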
        if args.bnb_4bit_compute_dtype is None:
            quantization_config.bnb_4bit_compute_dtype = None
        logger.info(f'quantization_config: {quantization_config.__dict__}')
        model_kwargs['quantization_config'] = quantization_config
    kwargs = {}
    if args.use_flash_attn is not None:
        kwargs['use_flash_attn'] = args.use_flash_attn
    model_id_or_path = None
    if args.sft_type == 'full' and args.ckpt_dir is not None:
        model_id_or_path = args.ckpt_dir
    elif args.model_id_or_path is not None:
        model_id_or_path = args.model_id_or_path
    if automodel_class is not None:
        kwargs['automodel_class'] = automodel_class
    if args.local_repo_path:
        kwargs['local_repo_path'] = args.local_repo_path
    if args.rope_scaling:
        kwargs['rope_scaling'] = args.rope_scaling
        kwargs['max_length'] = args.max_length
    model, tokenizer = get_model_tokenizer(
        args.model_type,
        args.torch_dtype,
        model_kwargs,
        model_id_or_path=model_id_or_path,
        revision=args.model_revision,
        **kwargs)
    if verbose:
        logger.info(f'model_config: {model.config}')

    generation_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        do_sample=args.do_sample,
        repetition_penalty=args.repetition_penalty,
        num_beams=args.num_beams,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')
    set_generation_config(model, generation_config)

    if model.max_model_len is None:
        model.max_model_len = args.max_model_len
    elif args.max_model_len is not None:
        if args.max_model_len <= model.max_model_len:
            model.max_model_len = args.max_model_len
        else:
            raise ValueError('args.max_model_len exceeds the maximum max_model_len supported by the model. '
                             f'args.max_model_len: {args.max_model_len}, model.max_model_len: {model.max_model_len}')
    # Preparing LoRA
    if is_adapter(args.sft_type) and args.ckpt_dir is not None:
        if isinstance(args, DeployArguments) and args.lora_request_list is not None:
            for lora_request in args.lora_request_list:
                model = Swift.from_pretrained(
                    model, lora_request.lora_local_path, lora_request.lora_name, inference_mode=True)
        else:
            model = Swift.from_pretrained(model, args.ckpt_dir, inference_mode=True)
        model = model.to(model.dtype)

    if verbose:
        show_layers(model)
        logger.info(model)
    logger.info(get_model_info(model))
    template: Template = get_template(
        args.template_type,
        tokenizer,
        args.system,
        args.max_length,
        args.truncation_strategy,
        model=model,
        tools_prompt=args.tools_prompt)
    args.system = template.default_system
    logger.info(f'system: {args.system}')
    return model, template


def read_media_file(infer_kwargs: Dict[str, Any], infer_media_type: Literal['none', 'round', 'dialogue']) -> None:
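    """Prompt for a media path/URL and append it to ``infer_kwargs['images']``.

    With ``infer_media_type='round'`` a file is requested every round; with
    ``'dialogue'`` only while no image has been provided yet; ``'none'`` skips
    the prompt entirely.
    """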
    text = 'Input a media path or URL <<< '
    images = infer_kwargs.get('images', [])
    if infer_media_type == 'none':
        return
    if infer_media_type == 'round' or len(images) == 0:
        image = input(text)
        if len(image) > 0:
            images += [image]
    if len(images) > 0:
        infer_kwargs['images'] = images


def llm_infer(args: InferArguments) -> Dict[str, Any]:
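    """Entry point for inference.

    With ``args.eval_human`` an interactive chat loop is started; otherwise the
    validation dataset is processed (optionally with the vLLM backend). Results
    are returned and, when ``args.save_result`` is set, also appended to a
    timestamped ``infer_result/*.jsonl`` file.
    """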
    logger.info(f'args: {args}')
    seed_everything(args.seed)
    if args.merge_lora:
        merge_lora(args, device_map=args.merge_device_map)

    if args.infer_backend == 'vllm':
        from .utils import prepare_vllm_engine_template, inference_stream_vllm, inference_vllm
        llm_engine, template = prepare_vllm_engine_template(args)
    else:
        device_map = None
        if args.device_map_config_path is not None:
            cwd = os.getcwd()
            config_path = args.device_map_config_path if os.path.isabs(args.device_map_config_path) else os.path.join(
                cwd, args.device_map_config_path)
            with open(config_path, 'r') as json_file:
                device_map = json.load(json_file)
        if args.quant_method == 'hqq':
            from transformers import HqqConfig
            if args.hqq_dynamic_config_path is not None:
                cwd = os.getcwd()
                config_path = args.hqq_dynamic_config_path if os.path.isabs(
                    args.hqq_dynamic_config_path) else os.path.join(cwd, args.hqq_dynamic_config_path)
                with open(config_path, 'r') as json_file:
                    args.quant_config = HqqConfig(dynamic_config=json.load(json_file))
            else:
                if args.quantization_bit == 0:
                    logger.info("You haven't set the quantization_bit parameter; set it to 8.")
                    args.quantization_bit = 8
                args.quant_config = HqqConfig(nbits=args.quantization_bit, axis=args.hqq_axis)
        elif args.quant_method == 'eetq':
            from transformers import EetqConfig
            args.quant_config = EetqConfig('int8')
        model, template = prepare_model_template(args, device_map=device_map)
        if args.overwrite_generation_config:
            assert args.ckpt_dir is not None, 'args.ckpt_dir is not specified.'
            model.generation_config.save_pretrained(args.ckpt_dir)
    lora_request = None
    if args.vllm_enable_lora:
        assert len(args.lora_request_list) == 1
        lora_request = args.lora_request_list[0]
    # Inference
    result = []
    jsonl_path = None
    if args.save_result:
        result_dir = args.ckpt_dir
        if result_dir is None:
            result_dir = llm_engine.model_dir if args.infer_backend == 'vllm' else model.model_dir
        if result_dir is not None:
            result_dir = os.path.join(result_dir, 'infer_result')
            os.makedirs(result_dir, exist_ok=True)
            time = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
            jsonl_path = os.path.join(result_dir, f'{time}.jsonl')
    if args.eval_human:
        input_mode: Literal['S', 'M'] = 'S'
        logger.info('Input `exit` or `quit` to exit the conversation.')
        logger.info('Input `multi-line` to switch to multi-line input mode.')
        logger.info('Input `reset-system` to reset the system and clear the history.')
        if template.support_multi_round:
            logger.info('Input `clear` to clear the history.')
        else:
            logger.info('The current template only supports single-round dialogues.')
        history = []
        infer_kwargs = {}
        if args.infer_media_type != 'none':
            logger.info('Please enter the conversation content first, followed by the path to the multimedia file.')
        system = None
        read_system = False
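        # Interactive loop: `exit`/`quit`, `clear`, `reset-system`, `multi-line` and
        # `single-line` are control commands; any other input is sent as a query.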
        while True:
            if input_mode == 'S':
                addi_prompt = ''
                if read_system:
                    addi_prompt = '[S]'
                query = input(f'<<<{addi_prompt} ')
            else:
                addi_prompt = '[M]'
                if read_system:
                    addi_prompt = '[MS]'
                query = read_multi_line(addi_prompt)
            if query.strip().lower() in {'exit', 'quit'}:
                break
            elif query.strip().lower() == 'clear':
                history = []
                infer_kwargs = {}
                continue
            elif query.strip() == '' and not read_system:
                continue
            elif query.strip().lower() == 'reset-system':
                read_system = True
                continue
            if read_system:
                if query == '':
                    system = None
                else:
                    system = query
                read_system = False
                history = []
                infer_kwargs = {}
                continue
            if input_mode == 'S' and query.strip().lower() == 'multi-line':
                input_mode = 'M'
                logger.info('End multi-line input with `#`.')
                logger.info('Input `single-line` to switch to single-line input mode.')
                continue
            if input_mode == 'M' and query.strip().lower() == 'single-line':
                input_mode = 'S'
                continue
            if not template.support_multi_round:
                history = []
                infer_kwargs = {}

            read_media_file(infer_kwargs, args.infer_media_type)
            if args.truncation_strategy:
                infer_kwargs['truncation_strategy'] = args.truncation_strategy
            if system is None and template.use_default_system:
                system = template.default_system
            if args.infer_backend == 'vllm':
                request_list = [{'query': query, 'history': history, 'system': system}]
                if args.stream:
                    gen = inference_stream_vllm(llm_engine, template, request_list, lora_request=lora_request)
                    print_idx = 0
                    for resp_list in gen:
                        response = resp_list[0]['response']
                        new_history = resp_list[0]['history']
                        if len(response) > print_idx:
                            print(response[print_idx:], end='', flush=True)
                            print_idx = len(response)
                    print()
                else:
                    resp_list = inference_vllm(llm_engine, template, request_list, lora_request=lora_request)
                    response = resp_list[0]['response']
                    new_history = resp_list[0]['history']
                    print(response)
            else:
                if args.stop_words:
                    infer_kwargs['stop_words'] = args.stop_words
                if args.stream:
                    gen = inference_stream(model, template, query, history, system, **infer_kwargs)
                    print_idx = 0
                    for response, new_history in gen:
                        if len(response) > print_idx:
                            print(response[print_idx:], end='', flush=True)
                            print_idx = len(response)
                    print()
                else:
                    response, new_history = inference(model, template, query, history, system, **infer_kwargs)
                    print(response)
            print('-' * 50)
            obj = {
                'system': system,
                'query': query,
                'response': response,
                'history': history,
            }
            images = infer_kwargs.get('images')
            if images is not None:
                obj['images'] = images
            history = new_history
            if jsonl_path is not None:
                append_to_jsonl(jsonl_path, obj)
            result.append(obj)
    else:
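        # Dataset mode: build the validation set, then run batched (vLLM) or
        # per-sample inference and collect the results.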
        dataset_kwargs = {
            'dataset_seed': args.dataset_seed,
            'check_dataset_strategy': args.check_dataset_strategy,
            'model_name': args.model_name,
            'model_author': args.model_author
        }
        if len(args.val_dataset) > 0:
            _, val_dataset = get_dataset(args.val_dataset, 1.0, **dataset_kwargs)
        else:
            _, val_dataset = get_dataset(args.dataset, args.dataset_test_ratio, **dataset_kwargs)
        _, val_dataset = args._handle_dataset_compat(_, val_dataset)
        assert val_dataset is not None
        if args.show_dataset_sample >= 0 and val_dataset.shape[0] > args.show_dataset_sample:
            random_state = np.random.RandomState(args.dataset_seed)
            logger.info(f'show_dataset_sample: {args.show_dataset_sample}')
            val_dataset = sample_dataset(val_dataset, args.show_dataset_sample, random_state)
        logger.info(f'val_dataset: {val_dataset}')

        if args.verbose is None:
            if len(val_dataset) >= 100:
                args.verbose = False
            else:
                args.verbose = True
            logger.info(f'Setting args.verbose: {args.verbose}')
        if not args.verbose and args.stream:
            args.stream = False
            logger.info(f'Setting args.stream: {args.stream}')

        if args.infer_backend == 'vllm' and not args.stream:
            if args.verbose:
                args.verbose = False
                logger.info('Setting args.verbose: False')
            label_list = None
            if 'response' in val_dataset.features:
                label_list = val_dataset['response']
                val_dataset = val_dataset.remove_columns('response')
            request_list = []
            for data in val_dataset:
                request = {'query': data['query']}
                history = data.get('history')
                system = data.get('system')
                images = data.get('images')
                if history is None:
                    history = []
                request['history'] = history
                if system is None and template.use_default_system:
                    system = template.default_system
                request['system'] = system
                if images is not None:
                    request['images'] = images
                if args.truncation_strategy:
                    request['truncation_strategy'] = args.truncation_strategy
                request_list.append(request)
            resp_list = inference_vllm(llm_engine, template, request_list, use_tqdm=True)
            result = []
            if label_list is not None:
                for request, label in zip(request_list, label_list):
                    request['label'] = label
            for request, resp in zip(request_list, resp_list):
                obj = {
                    'system': request['system'],
                    'query': request['query'],
                    'response': resp['response'],
                    'label': request.pop('label', None),
                    'history': request['history'],
                }
                images = request.get('images')
                if images is not None:
                    obj['images'] = images
                if jsonl_path is not None:
                    append_to_jsonl(jsonl_path, obj)
                result.append(obj)
        else:
            if not args.verbose:
                val_dataset = tqdm(val_dataset)
            for data in val_dataset:
                kwargs = {'query': data['query']}
                history = data.get('history')
                system = data.get('system')
                images = data.get('images')
                tools = data.get('tools')
                if args.verbose and system is not None:
                    print(f'[SYSTEM]{system}')
                if history is None:
                    history = []
                kwargs['history'] = history
                if system is None and template.use_default_system:
                    system = template.default_system
                kwargs['system'] = system
                if images is not None:
                    kwargs['images'] = images
                if tools is not None:
                    kwargs['tools'] = tools
                if args.truncation_strategy:
                    kwargs['truncation_strategy'] = args.truncation_strategy
                if args.infer_backend == 'vllm':
                    assert args.stream is True
                    if args.verbose:
                        print(f"[QUERY]{data['query']}\n[RESPONSE]", end='')
                    gen = inference_stream_vllm(llm_engine, template, [kwargs], lora_request=lora_request)
                    print_idx = 0
                    for resp_list in gen:
                        response = resp_list[0]['response']
                        if args.verbose and len(response) > print_idx:
                            print(response[print_idx:], end='', flush=True)
                            print_idx = len(response)
                    print()
                else:
                    response, _ = inference(
                        model, template, stream=args.stream and args.verbose, verbose=args.verbose, **kwargs)
                label = data.pop('response', None)
                obj = {
                    'system': kwargs['system'],
                    'query': kwargs['query'],
                    'response': response,
                    'label': label,
                    'history': kwargs['history'],
                }
                if images is not None:
                    obj['images'] = images
                if jsonl_path is not None:
                    append_to_jsonl(jsonl_path, obj)
                result.append(obj)
                if args.verbose:
                    print()
                    print(f'[LABELS]{label}')
                    if images is not None:
                        print(f'[IMAGES]{images}')
                    print('-' * 50)
    if jsonl_path is not None:
        logger.info(f'save_result_path: {jsonl_path}')
    if not args.eval_human and args.show_dataset_sample == 10:  # is default
        logger.info('You can set `--show_dataset_sample -1` to perform inference on the entire dataset.')
    return {'result': result}


infer_main = get_main(InferArguments, llm_infer)
merge_lora_main = get_main(InferArguments, merge_lora)