# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import json
from typing import Optional

import fire
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


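# vLLM is an optional dependency, so its classes are imported only when it is installed.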
if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest


def vllm_infer(
    model_name_or_path: str,
    adapter_name_or_path: Optional[str] = None,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 2048,
    max_samples: Optional[int] = None,
    vllm_config: str = "{}",
    save_name: str = "generated_predictions.jsonl",
    temperature: float = 0.95,
    top_p: float = 0.7,
    top_k: int = 50,
    max_new_tokens: int = 1024,
    repetition_penalty: float = 1.0,
    skip_special_tokens: bool = True,
    default_system: Optional[str] = None,
    enable_thinking: bool = True,
    seed: Optional[int] = None,
    pipeline_parallel_size: int = 1,
    image_max_pixels: int = 768 * 768,
    image_min_pixels: int = 32 * 32,
    video_fps: float = 2.0,
    video_maxlen: int = 128,
    batch_size: int = 1024,
):
    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.

    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
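    Example with a LoRA adapter (the adapter path here is illustrative): python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --adapter_name_or_path saves/llama2-lora --template llama --dataset alpaca_en_demo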
    """
    if pipeline_parallel_size > get_device_count():
        raise ValueError("Pipeline parallel size should not exceed the number of GPUs.")

    model_args, data_args, _, generating_args = get_infer_args(
        dict(
            model_name_or_path=model_name_or_path,
            adapter_name_or_path=adapter_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            max_samples=max_samples,
            preprocessing_num_workers=16,
            default_system=default_system,
            enable_thinking=enable_thinking,
            vllm_config=vllm_config,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
    )

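    # The dummy training arguments only satisfy get_dataset's signature; no training happens here.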
    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate

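    # Build the vLLM engine arguments: the tensor parallel size is the available device
    # count divided by the pipeline parallel size (falling back to 1 on CPU-only hosts),
    # max_model_len covers the prompt (cutoff_len) plus the generation budget
    # (max_new_tokens), and LoRA support is enabled only when an adapter is given.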
    engine_args = {
        "model": model_args.model_name_or_path,
        "trust_remote_code": True,
        "dtype": model_args.infer_dtype,
        "max_model_len": cutoff_len + max_new_tokens,
        "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
        "pipeline_parallel_size": pipeline_parallel_size,
        "disable_log_stats": True,
        "enable_lora": model_args.adapter_name_or_path is not None,
    }
    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

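    # If vllm_config was parsed into a dict, its entries override the defaults above.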
    if isinstance(model_args.vllm_config, dict):
        engine_args.update(model_args.vllm_config)

    llm = LLM(**engine_args)

    # load datasets
    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)
    train_dataset = dataset_module["train_dataset"]

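    # Falsy sampling values are coerced to settings vLLM accepts: repetition_penalty
    # and top_p must be positive, and top_k=-1 means all tokens are considered.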
    sampling_params = SamplingParams(
        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must be > 0
        temperature=generating_args.temperature,
        top_p=generating_args.top_p or 1.0,  # top_p must be > 0
        top_k=generating_args.top_k or -1,  # vLLM rejects top_k=0; -1 disables top-k
        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
        max_tokens=generating_args.max_new_tokens,
        skip_special_tokens=skip_special_tokens,
        seed=seed,
    )
    if model_args.adapter_name_or_path is not None:
        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
    else:
        lora_request = None

    # Store all results in these lists
    all_prompts, all_preds, all_labels = [], [], []

    # Process the dataset in batches to avoid hitting the "too many open files" limit
    for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
        vllm_inputs, prompts, labels = [], [], []
        batch = train_dataset[i : min(i + batch_size, len(train_dataset))]

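        # Each example contributes its prompt token ids plus any raw media (images,
        # videos, or audio), regularized by the template's mm_plugin and passed to
        # vLLM as multi_modal_data.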
        for j in range(len(batch["input_ids"])):
            if batch["images"][j] is not None:
                image = batch["images"][j]
                multi_modal_data = {
                    "image": template_obj.mm_plugin._regularize_images(
                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
                    )["images"]
                }
            elif batch["videos"][j] is not None:
                video = batch["videos"][j]
                multi_modal_data = {
                    "video": template_obj.mm_plugin._regularize_videos(
                        video,
                        image_max_pixels=image_max_pixels,
                        image_min_pixels=image_min_pixels,
                        video_fps=video_fps,
                        video_maxlen=video_maxlen,
                    )["videos"]
                }
            elif batch["audios"][j] is not None:
                audio = batch["audios"][j]
                audio_data = template_obj.mm_plugin._regularize_audios(
                    audio,
                    sampling_rate=16000,
                )
                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
            else:
                multi_modal_data = None

            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
            prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
            labels.append(
                tokenizer.decode(
                    list(filter(lambda x: x != IGNORE_INDEX, batch["labels"][j])),
                    skip_special_tokens=skip_special_tokens,
                )
            )

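        # Generate for the whole batch; the LoRA request (if any) applies to every prompt,
        # and each result's single completion (n defaults to 1) is kept as the prediction.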
        results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
        preds = [result.outputs[0].text for result in results]

        # Accumulate results
        all_prompts.extend(prompts)
        all_preds.extend(preds)
        all_labels.extend(labels)
        gc.collect()

    # Write all results at once outside the loop
    with open(save_name, "w", encoding="utf-8") as f:
        for text, pred, label in zip(all_prompts, all_preds, all_labels):
            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    print("*" * 70)
    print(f"{len(all_prompts)} total generated results have been saved at {save_name}.")
    print("*" * 70)


if __name__ == "__main__":
    fire.Fire(vllm_infer)