from collections.abc import Callable
from typing import Any

import cloudpickle
from pydantic import ValidationError
from tqdm import tqdm

# External library imports (vLLM)
from vllm.config import CompilationConfig, StructuredOutputsConfig, is_init_field
from vllm.entrypoints.llm import LLM
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.usage.usage_lib import UsageContext
from vllm.utils.counter import Counter
from vllm.v1.engine.llm_engine import LLMEngine

from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors

# Internal imports (our code)
from vllm_omni.engine.arg_utils import OmniEngineArgs
from vllm_omni.engine.input_processor import OmniInputProcessor
from vllm_omni.engine.output_processor import MultimodalOutputProcessor
from vllm_omni.entrypoints.utils import (
    load_stage_configs_from_model,
    load_stage_configs_from_yaml,
    resolve_model_config_path,
)

logger = init_logger(__name__)


class OmniLLM(LLM):
    """Main entry point for vLLM-Omni inference.

    This class extends the base vLLM LLM class with omni-specific
    processors for handling multimodal inputs and outputs. It provides
    configuration loading for multi-stage pipelines, while stage management
    is handled by the Omni class.

    Args:
        model: Model name or path to load
        stage_configs_path: Optional path to YAML file containing stage
            configurations. If None, configurations are loaded from the model.
        log_stats: Whether to enable statistics logging
        compilation_config: Optional compilation configuration. Can be an
            integer (compilation level), dict, or CompilationConfig instance.
        hf_overrides: Optional HuggingFace model configuration overrides
        structured_outputs_config: Optional structured outputs configuration.
            Can be a dict or StructuredOutputsConfig instance.
        init_sleep_seconds: Number of seconds to sleep between starting
            each stage process during initialization (used by Omni class)
        shm_threshold_bytes: Threshold in bytes for using shared memory
            for IPC. Objects larger than this threshold will use shared memory.
        batch_timeout: Timeout in seconds for batching requests within a stage
        init_timeout: Timeout in seconds for waiting for all stages to initialize
        **kwargs: Additional keyword arguments passed to the base LLM class
            and engine

    Example:
        >>> llm = OmniLLM(model="Qwen/Qwen2.5-Omni-7B")
        >>> # Stage management is handled by Omni class
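        >>> # Or point at an explicit stage-config YAML (the path below is illustrative only):
        >>> llm = OmniLLM(
        ...     model="Qwen/Qwen2.5-Omni-7B",
        ...     stage_configs_path="path/to/stage_configs.yaml",
        ... )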
    """

    def __init__(
        self,
        model: str,
        stage_configs_path: str | None = None,
        log_stats: bool = False,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        hf_overrides: dict[str, Any] | None = None,
        structured_outputs_config: dict[str, Any] | StructuredOutputsConfig | None = None,
        init_sleep_seconds: int = 20,
        shm_threshold_bytes: int = 65536,
        batch_timeout: int = 10,
        init_timeout: int = 300,
        **kwargs: Any,
    ):
        """LLM constructor with omni-specific configuration loading."""
        # Store stage management parameters (used by Omni class)
        self.worker_backend = kwargs.get("worker_backend", "multi_process")
        self.ray_address = kwargs.get("ray_address", None)
        self.batch_timeout = batch_timeout
        self._enable_stats: bool = bool(log_stats)

        # Load stage configurations
        if stage_configs_path is None:
            self.config_path = resolve_model_config_path(model)
            self.stage_configs = load_stage_configs_from_model(model)
        else:
            self.config_path = stage_configs_path
            self.stage_configs = load_stage_configs_from_yaml(stage_configs_path)

        # Initialize orchestrator connectors for inter-stage IPC (objects above shm_threshold_bytes use shared memory)
        self.omni_transfer_config, self.connectors = initialize_orchestrator_connectors(
            self.config_path, worker_backend=self.worker_backend, shm_threshold_bytes=shm_threshold_bytes
        )

        # Prepare engine kwargs: keep engine-level stats logging disabled unless the caller overrides it
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a fully qualified string name,
            # serialize it with cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

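        # Convert a kv_transfer_config passed as a plain dict into a validated KVTransferConfig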
        if "kv_transfer_config" in kwargs and isinstance(kwargs["kv_transfer_config"], dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        # Extract omni_kv_config from kwargs if present (injected by Omni)
        omni_kv_config = kwargs.pop("omni_kv_config", None)

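        # Normalize compilation_config: accept an int level, a dict of init fields, or a ready CompilationConfig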
        if compilation_config is not None:
            if isinstance(compilation_config, int):
                compilation_config_instance = CompilationConfig(level=compilation_config)
            elif isinstance(compilation_config, dict):
                compilation_config_instance = CompilationConfig(
                    **{k: v for k, v in compilation_config.items() if is_init_field(CompilationConfig, k)}
                )
            else:
                compilation_config_instance = compilation_config
        else:
            compilation_config_instance = CompilationConfig()

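        # Normalize structured_outputs_config the same way: a dict of init fields or a ready instance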
        if structured_outputs_config is not None:
            if isinstance(structured_outputs_config, dict):
                structured_outputs_instance = StructuredOutputsConfig(
                    **{k: v for k, v in structured_outputs_config.items() if is_init_field(StructuredOutputsConfig, k)}
                )
            else:
                structured_outputs_instance = structured_outputs_config
        else:
            structured_outputs_instance = StructuredOutputsConfig()

        engine_args = OmniEngineArgs(
            model=model,
            compilation_config=compilation_config_instance,
            structured_outputs_config=structured_outputs_instance,
            omni_kv_config=omni_kv_config,
            **kwargs,
        )

        # Create the engine from the assembled engine args
        self.llm_engine = LLMEngine.from_engine_args(engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
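        # Replace the default processors with omni-specific output/input processors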
        self.llm_engine.output_processor = MultimodalOutputProcessor(
            tokenizer=self.llm_engine.tokenizer,
            log_stats=self.llm_engine.log_stats,
            engine_core_output_type=engine_args.engine_output_type,
        )
        self.llm_engine.input_processor = OmniInputProcessor(vllm_config=self.llm_engine.vllm_config)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()  # type: ignore

        logger.info("Supported tasks: %s", supported_tasks)

        self.supported_tasks = supported_tasks

        # Load the input/output processor plugin, if any
        io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config, io_processor_plugin)
        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor

    def close(self) -> None:
        """Close resources.

        Note: Stage management is now handled by the Omni class.
        This method closes the LLM engine but not the stages.
        """
        # Close the LLM engine if it exists
        if hasattr(self, "llm_engine") and self.llm_engine is not None:
            if hasattr(self.llm_engine, "shutdown"):
                self.llm_engine.shutdown()

    def __del__(self) -> None:  # best-effort
        try:
            self.close()
        except Exception as e:
            logger.debug("[Orchestrator] __del__ close() raised: %s", e, exc_info=True)

    def _run_engine(self, *, use_tqdm: bool | Callable[..., tqdm] = True) -> list[RequestOutput | PoolingRequestOutput]:
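        """Step the engine until all unfinished requests complete, optionally reporting progress via tqdm.

        Finished outputs are collected and returned sorted by the integer prefix of their request IDs.
        """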
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
                            pbar.postfix = f"est. speed input: {in_spd:.2f} toks/s, output: {out_spd:.2f} toks/s"
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by the integer part of the request ID, which has the format 'int-uuid'.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id.split("-")[0]))